gaoqiong / composable_kernel_ROCM

Commit 134fc2e7, authored Dec 19, 2023 by Adam Osewski

Fix StorePartials.

Pass pointer to whole workspace not the shifted one.

Parent: 88a4fbfb

Showing 3 changed files with 68 additions and 26 deletions (+68 -26)
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp  +1 -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp  +1 -5
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp  +66 -21
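The gist of the fix: the device-side caller used to advance the workspace pointer to the current block's tile before calling StorePartials, while the updated StorePartials locates that tile itself from blockIdx.x through a grid-sized workspace descriptor, so it now needs the unshifted base pointer. A minimal sketch in plain C++ (store_partials, TileElems, and the loop are illustrative stand-ins, not the CK API) of why shifting in both places would go wrong:

    #include <cstdio>
    #include <vector>

    constexpr int MPerBlock = 4;
    constexpr int NPerBlock = 4;
    constexpr int TileElems = MPerBlock * NPerBlock;

    // Hypothetical stand-in for GridwiseGemm::StorePartials: the callee locates the
    // block's slice of the workspace from the block index itself.
    void store_partials(float* p_workspace_whole, int block_id, float value)
    {
        float* p_tile = p_workspace_whole + block_id * TileElems;
        for(int i = 0; i < TileElems; ++i)
            p_tile[i] = value;
    }

    int main()
    {
        constexpr int grid_size = 3;
        std::vector<float> workspace(grid_size * TileElems, 0.f);

        for(int block_id = 0; block_id < grid_size; ++block_id)
        {
            // Correct (this commit): pass the whole, unshifted workspace.
            store_partials(workspace.data(), block_id, 1.f);

            // Pattern removed by this commit: shifting in the caller *and* indexing by
            // block_id in the callee applies the offset twice; the last block would
            // write past the end of the workspace:
            //   store_partials(workspace.data() + block_id * TileElems, block_id, 1.f);
        }
        std::printf("last element: %.1f\n", workspace.back());
        return 0;
    }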
example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp

@@ -319,6 +319,7 @@ int main(int argc, char* argv[])
     if(argc < 11)
     {
         std::vector<ck::index_t> Ms{64, 127, 255, 129, 260, 190, 77};
+        problem_size.group_count = Ms.size();

         for(int i = 0; i < problem_size.group_count; i++)
         {
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp

@@ -162,11 +162,7 @@ __global__ void
         // if (changed group_id || next [M,N] tile)
         if(!b2c_tile_map.IsFirstKSplitBlock())
         {
-            void* __restrict__ p_block_workspace = reinterpret_cast<void* __restrict__>(
-                reinterpret_cast<char*>(p_workspace) +
-                blockIdx.x * GridwiseGemm::GetMPerBlock() * GridwiseGemm::GetNPerBlock() *
-                    sizeof(typename GridwiseGemm::AccType));
-            gridwise_gemm.StorePartials(p_block_workspace);
+            gridwise_gemm.StorePartials(p_workspace);
         }

         work_scheduler.FlagFinished(k_batch, output_tile_idx, output_tile_idx_offset);
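The removed caller-side shift advanced the pointer by blockIdx.x * MPerBlock * NPerBlock * sizeof(AccType) bytes, i.e. one MPerBlock x NPerBlock tile of accumulator values per workgroup. A small sketch of that arithmetic (the 128 x 128 tile and float accumulator are illustrative values, giving 64 KiB per block):

    #include <cstddef>
    #include <cstdio>

    int main()
    {
        using AccType = float;             // split-K partials are stored as accumulator values
        const std::size_t MPerBlock = 128; // illustrative tile sizes, not fixed by the kernel
        const std::size_t NPerBlock = 128;

        for(std::size_t block_idx = 0; block_idx < 4; ++block_idx)
        {
            // Byte offset the old caller applied before calling StorePartials.
            const std::size_t offset_bytes =
                block_idx * MPerBlock * NPerBlock * sizeof(AccType);
            std::printf("block %zu -> workspace byte offset %zu (%zu KiB)\n",
                        block_idx, offset_bytes, offset_bytes / 1024);
        }
        return 0;
    }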
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp

@@ -814,26 +814,71 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
    {
        const auto& c_thread_buf = blockwise_gemm_.GetCThreadBuffer();

-        // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
-        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
-            BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
-
-        constexpr auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
-            c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2;
+        // M0 = grid_size
+        // N0 = 1
+        // M1 = MPerBlock
+        // N1 = NPerBlock
+        const auto workspace_grid_desc_m0_n0_m1_n1 =
+            MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock(get_grid_size());
+
+        const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
+        const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);
+
+        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
+            BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
+
+        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
+        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
+        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
+        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
+        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
+        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
+        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
+        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
+
+        // M0 = grid_size -> MRepeats (MXdlPerWave)
+        // N0 = 1 -> NRepeats (NXdlPerWave)
+        const auto workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = transform_tensor_descriptor(
+            workspace_grid_desc_m0_n0_m1_n1,
+            make_tuple(make_pass_through_transform(w_grid_m0),
+                       make_pass_through_transform(w_grid_n0),
+                       make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
+                       make_unmerge_transform(make_tuple(N0, N1, N2))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(
+                Sequence<0>{}, Sequence<1>{}, Sequence<2, 4, 6, 7, 8>{}, Sequence<3, 5, 9>{}));
+
+        const auto workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
+            workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
+            make_tuple(make_merge_transform(make_tuple(w_grid_m0, M0)), // MRepeats (grid)
+                       make_merge_transform(make_tuple(w_grid_n0, N0)), // NRepeats (grid)
+                       make_pass_through_transform(M1),  // MWave
+                       make_pass_through_transform(N1),  // NWave
+                       make_pass_through_transform(M2),  // mfma_instr.num_groups_per_blk
+                       make_pass_through_transform(M3),  // mfma_instr.num_input_blks
+                       make_pass_through_transform(M4),  // mfma_instr.group_size
+                       make_pass_through_transform(N2)), // mfma_instr.num_threads_per_blk
+            make_tuple(Sequence<0, 2>{},
+                       Sequence<1, 3>{},
+                       Sequence<4>{},
+                       Sequence<5>{},
+                       Sequence<6>{},
+                       Sequence<7>{},
+                       Sequence<8>{},
+                       Sequence<9>{}),
+            make_tuple(Sequence<0>{},
+                       Sequence<1>{},
+                       Sequence<2>{},
+                       Sequence<3>{},
+                       Sequence<4>{},
+                       Sequence<5>{},
+                       Sequence<6>{},
+                       Sequence<7>{}));

        auto p_workspace_grid = reinterpret_cast<AccDataType*>(p_workspace);
        auto w_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_workspace_grid, workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());

-        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
-        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
-        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
-        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
-        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
-        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
-        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
-        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
-
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            BlockwiseGemmT::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
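In the new StorePartials path, the per-block offset moves into the tensor descriptor: MakeWorkspaceGridDesc_GridSize_I1_MPerBlock_NPerBlock views the whole workspace as [grid_size, 1, MPerBlock, NPerBlock], the unmerge/merge transforms fold the grid dimension into the M/N repeat dimensions, and (as the next hunk shows) the thread copy's destination origin becomes blockIdx.x * MXdlPerWave in that merged dimension. A minimal sketch in plain C++ (illustrative row-major layout and helper names, not the CK descriptor API) showing that this grid-level addressing picks the same element as the old shifted-pointer scheme:

    #include <cassert>
    #include <cstddef>

    constexpr std::size_t MPerBlock = 128; // illustrative tile sizes
    constexpr std::size_t NPerBlock = 128;

    // Offset of element (m, n) of block `blk` when the whole workspace is viewed as a
    // row-major [grid_size, MPerBlock, NPerBlock] tensor (the N0 = 1 axis is dropped).
    constexpr std::size_t grid_view_offset(std::size_t blk, std::size_t m, std::size_t n)
    {
        return (blk * MPerBlock + m) * NPerBlock + n;
    }

    // Offset reached by the old scheme: shift the base pointer by one tile per block,
    // then index (m, n) inside that tile.
    constexpr std::size_t shifted_ptr_offset(std::size_t blk, std::size_t m, std::size_t n)
    {
        return blk * MPerBlock * NPerBlock + (m * NPerBlock + n);
    }

    int main()
    {
        for(std::size_t blk = 0; blk < 3; ++blk)
            for(std::size_t m = 0; m < MPerBlock; m += 37)
                for(std::size_t n = 0; n < NPerBlock; n += 29)
                    assert(grid_view_offset(blk, m, n) == shifted_ptr_offset(blk, m, n));
        return 0;
    }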
@@ -869,14 +914,15 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
            decltype(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            ck::tensor_operation::element_wise::PassThrough,
            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLengths()), // SliceLengths
-            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,                             // DimAccessOrder
+            // N -> then M dims
+            Sequence<0, 2, 4, 5, 6, 1, 3, 7>,                             // DimAccessOrder
            7,                                                             // DstVectorDim,
            1,                                                             // DstScalarPerVector
            InMemoryDataOperationEnum::Set,
            1,                                                             // DstScalarStrideInVector
            true>{                                                         // DstResetCoordinateAfterRun
                workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                make_multi_index(m_thread_data_on_block_idx[I0],
+                make_multi_index((static_cast<index_t>(blockIdx.x)) * MXdlPerWave,
                                 n_thread_data_on_block_idx[I0],
                                 m_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I1],
@@ -916,7 +962,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
        const auto w_grid_m0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I0);
        const auto w_grid_n0 = workspace_grid_desc_m0_n0_m1_n1.GetLength(I1);

-        // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            BlockwiseGemmT::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
@@ -929,8 +974,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

-        // M0 = grid_size -> MRepeats
-        // N0 = 1 -> NRepeats
+        // M0 = grid_size -> MRepeats (MXdlPerWave)
+        // N0 = 1 -> NRepeats (NXdlPerWave)
        const auto workspace_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3 = transform_tensor_descriptor(
            workspace_grid_desc_m0_n0_m1_n1,
            make_tuple(make_pass_through_transform(w_grid_m0),
@@ -1003,7 +1048,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
            decltype(workspace_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),        // SrcDesc,
            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),              // DstDesc,
            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLengths()), // SliceLengths,
-            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,                             // DimAccessOrder,
+            Sequence<0, 2, 4, 5, 6, 1, 3, 7>,                             // DimAccessOrder,
            7,                                                             // SrcVectorDim,
            1,                                                             // SrcScalarPerVector,
            1,                                                             // SrcScalarStrideInVector,