gaoqiong / composable_kernel_ROCM · Commits

Commit eaa68635
authored Jun 04, 2024 by Adam Osewski

Fix RunWrite.

parent 125a39d1
Showing 2 changed files with 177 additions and 150 deletions.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp (+1, -1)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp (+176, -149)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp @ eaa68635
@@ -101,7 +101,7 @@ __global__ void
     index_t gemm_tile_id_end = grid_size_grp;
     StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                               typename GridwiseGemm::AccType,
-                              GridwiseGemm::GetMPerXdl() * GridwiseGemm::GetNPerXdl(),
+                              GridwiseGemm::GetMXdlPerWave() * GridwiseGemm::GetNXdlPerWave(),
                               GridwiseGemm::GetCThreadBufferVectorSize(),
                               true>
         results_buffer;
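The changed template argument resizes the kernel's per-thread results buffer: it now holds one accumulator vector per XDL instruction the wave repeats over its tile (MXdlPerWave * NXdlPerWave) rather than one per element of a whole XDL tile (MPerXdl * NPerXdl). A minimal sketch of the arithmetic, with illustrative tile parameters that are not taken from the diff:

```cpp
#include <cstdio>

// Illustrative tile parameters only; real values come from the gridwise GEMM config.
constexpr int MPerXdl     = 32; // rows covered by one XDL (MFMA) instruction
constexpr int NPerXdl     = 32; // columns covered by one XDL instruction
constexpr int MXdlPerWave = 4;  // XDL repeats of one wave along M
constexpr int NXdlPerWave = 2;  // XDL repeats of one wave along N

int main()
{
    // Old sizing: one vector per element of an XDL tile -- far more VGPRs than a thread has.
    constexpr int num_vectors_old = MPerXdl * NPerXdl;         // 1024
    // New sizing: one vector per XDL repeat issued by the wave.
    constexpr int num_vectors_new = MXdlPerWave * NXdlPerWave; // 8
    std::printf("old: %d vectors, new: %d vectors per thread\n",
                num_vectors_old, num_vectors_new);
    return 0;
}
```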
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle_v2.hpp @ eaa68635
@@ -308,8 +308,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     // M0 - MBlock
     // M1 - MPerBlock
     // N0 - NBlock
-    // N1 - NVecPerThread
-    // N2 - NVecSize
+    // N1 - N repeats
+    // N2 - NVecSize * cluster length
     template <typename EGridDesc_M_N>
     __host__ __device__ static constexpr auto
     MakeEGridDescriptor_M0M1_N0N1N2(const EGridDesc_M_N& e_grid_desc_m_n)
@@ -330,33 +330,18 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
     constexpr auto workspace_thread_desc_m0m1_n0n1n2 = MakeReductionThreadDesc_M0M1_N0N1N2();

-    // # of threads in NDim * vector load size * # repeats per thread
-    constexpr auto NPerBlockPadded = cluster_length_reduce.At(I2) *
-                                     workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) *
-                                     workspace_thread_desc_m0m1_n0n1n2.GetLength(I4);
-    constexpr auto NPerBlockPad = NPerBlockPadded - Number<NPerBlock>{};
-
-    const auto e_grid_desc_m0m1_n0n1pad = transform_tensor_descriptor(
+    const auto e_grid_desc_m0m1_n0n1n2 = transform_tensor_descriptor(
         e_grid_desc_mblock_mperblock_nblock_nperblock,
         make_tuple(make_pass_through_transform(
                        e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0)),
                    make_pass_through_transform(
                        e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)),
                    make_pass_through_transform(
                        e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)),
-                   make_right_pad_transform(Number<NPerBlock>{}, NPerBlockPad)),
-        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
-
-    const auto e_grid_desc_m0m1_n0n1n2 = transform_tensor_descriptor(
-        e_grid_desc_m0m1_n0n1pad,
-        make_tuple(make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I0)),
-                   make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I1)),
-                   make_pass_through_transform(e_grid_desc_m0m1_n0n1pad.GetLength(I2)),
-                   make_unmerge_transform(
-                       make_tuple(workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) *
-                                      cluster_length_reduce.At(I2),
-                                  workspace_thread_desc_m0m1_n0n1n2.GetLength(I4)))),
+                   make_unmerge_transform(
+                       make_tuple(workspace_thread_desc_m0m1_n0n1n2.GetLength(I3),
+                                  workspace_thread_desc_m0m1_n0n1n2.GetLength(I4) *
+                                      cluster_length_reduce.At(I2)))),
         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
         make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{}));
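The two-step pad-then-unmerge construction collapses into a single unmerge. That only works if NPerBlock factors exactly into N1 repeats of (N2 * cluster-length) chunks; the `==` static_asserts added to MakeReductionThreadDesc_M0M1_N0N1N2 further down enforce exactly that. A small self-contained check of the index split, with made-up sizes:

```cpp
#include <cassert>

int main()
{
    // Made-up reduction layout; names mirror the diff, values are illustrative.
    constexpr int NPerBlock = 128;
    constexpr int cluster_n = 16;                           // cluster_length_reduce.At(I2)
    constexpr int N2        = 4;                            // vector (scalar-per-vector) size
    constexpr int N1        = NPerBlock / (cluster_n * N2); // repeats per thread -> 2

    // The unmerge splits the N axis into (N1, N2 * cluster_n); with no right-pad this
    // must tile NPerBlock exactly.
    static_assert(N1 * (N2 * cluster_n) == NPerBlock, "exact split, no padding needed");

    // Every (n1, n2) coordinate maps back into the block without going out of bounds.
    for(int n1 = 0; n1 < N1; ++n1)
        for(int n2 = 0; n2 < N2 * cluster_n; ++n2)
            assert(n1 * (N2 * cluster_n) + n2 < NPerBlock);
    return 0;
}
```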
@@ -436,8 +421,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
        GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
        GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
     {
-        auto K_t = KBatch * KPerBlock;
-        if(!(K % K_t == 0))
+        if(!(K % KPerBlock == 0))
         {
             if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
             {
@@ -540,6 +524,33 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             }
         }

+        const auto k_batch_size =
+            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KBatch;
+
+        if(k_batch_size < KPerBlock)
+        {
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "The k-batch size (" << k_batch_size
+                          << ") value is less than KPerBlock!\n"
+                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                          << std::endl;
+            }
+            return false;
+        }
+
+        if(k_batch_size % KPerBlock != 0)
+        {
+            if(ck::EnvIsEnabled(ENV(CK_LOGGING)))
+            {
+                std::cout << "The k-batch size (" << k_batch_size
+                          << ") value is not a multiple of KPerBlock!\n"
+                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
+                          << std::endl;
+            }
+            return false;
+        }
+
         // check gridwise gemm pipeline
         const auto num_k_loop = (a_grid_desc_ak0_m_ak1.GetLength(I0) *
                                  a_grid_desc_ak0_m_ak1.GetLength(I2)) /
...
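The added checks reject split-K configurations in which a single K batch is smaller than, or not a whole multiple of, one K tile. The same rule in a standalone sketch, with plain ints standing in for the descriptor lengths (AK0 * AK1 is the total K extent of the A descriptor):

```cpp
#include <cstdio>

// Sketch of the added validity rule; ak0 and ak1 stand in for
// a_grid_desc_ak0_m_ak1.GetLength(I0) and GetLength(I2).
bool IsValidKBatch(int ak0, int ak1, int KBatch, int KPerBlock)
{
    const int k_batch_size = (ak0 * ak1) / KBatch; // K elements owned by one split-K batch
    if(k_batch_size < KPerBlock)
        return false; // batch too small to fill even one K tile
    if(k_batch_size % KPerBlock != 0)
        return false; // batch must hold a whole number of K tiles
    return true;
}

int main()
{
    std::printf("%d\n", IsValidKBatch(64, 8, 4, 32));  // 1: 512 / 4 = 128 = 4 * 32
    std::printf("%d\n", IsValidKBatch(64, 8, 4, 256)); // 0: 128 < 256
    return 0;
}
```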
@@ -624,8 +635,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
     __device__ __host__ static constexpr auto GetNPerBlock() { return NPerBlock; }
-    __device__ __host__ static constexpr auto GetMPerXdl() { return MPerXdl; }
-    __device__ __host__ static constexpr auto GetNPerXdl() { return NPerXdl; }
+    __device__ __host__ static constexpr auto GetMXdlPerWave() { return MXdlPerWave; }
+    __device__ __host__ static constexpr auto GetNXdlPerWave() { return NXdlPerWave; }

     __device__ static constexpr auto GetCThreadBufferVectorSize()
     {
@@ -646,7 +657,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         return BlockwiseGemmT::xdlops_gemm.GetRegSizePerXdlops();
     }

-    template <bool HasMainKBlockLoop, typename Block2ETileMap, typename CThreadBuf>
+    template <typename Block2ETileMap, typename CThreadBuf>
     __device__ static void RunGEMM(const ADataType* __restrict__ p_a_grid,
                                    const BDataType* __restrict__ p_b_grid,
                                    void* __restrict__ p_shared,
@@ -656,7 +667,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                    const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                                    const Block2ETileMap& block_2_etile_map,
                                    CThreadBuf& c_thread_buf,
-                                   const index_t k_tiles)
+                                   const index_t k_batch,
+                                   const index_t next_k_tiles)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -726,16 +738,11 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                                        true,
                                                        NumGemmKPrefetchStage>;

-        const index_t ak0_start_idx = kbatch_id * AK0PerBlock;
-        const index_t bk0_start_idx = kbatch_id * BK0PerBlock;
-
-        if(blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
-        {
-            printf("[RunGEMM] bid: %d, ak0_start_idx: %d, bk0_start_idx: %d\n",
-                   static_cast<index_t>(blockIdx.x),
-                   ak0_start_idx,
-                   bk0_start_idx);
-        }
+        const index_t num_k_tiles_per_batch =
+            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
+            (KPerBlock * k_batch);
+        const index_t ak0_start_idx = kbatch_id * num_k_tiles_per_batch * AK0PerBlock.value;
+        const index_t bk0_start_idx = kbatch_id * num_k_tiles_per_batch * BK0PerBlock.value;

         // A matrix blockwise copy
         auto a_blockwise_copy =
...
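This is the core indexing fix for tile-loop split-K: one batch may own several K tiles, so each batch's start offset must skip num_k_tiles_per_batch whole tiles, not a single one. A worked example with illustrative sizes:

```cpp
#include <cstdio>

int main()
{
    // Illustrative sizes; names follow the diff.
    constexpr int AK0Total    = 64; // a_grid_desc_ak0_m_ak1.GetLength(I0)
    constexpr int AK1         = 8;  // a_grid_desc_ak0_m_ak1.GetLength(I2)
    constexpr int KPerBlock   = 32;
    constexpr int k_batch     = 4;
    constexpr int AK0PerBlock = KPerBlock / AK1; // AK0 extent of one K tile -> 4

    // Tiles along K owned by each split-K batch: (64 * 8) / (32 * 4) = 4.
    constexpr int num_k_tiles_per_batch = (AK0Total * AK1) / (KPerBlock * k_batch);

    for(int kbatch_id = 0; kbatch_id < k_batch; ++kbatch_id)
    {
        const int old_start = kbatch_id * AK0PerBlock;                         // 0, 4, 8, 12
        const int new_start = kbatch_id * num_k_tiles_per_batch * AK0PerBlock; // 0, 16, 32, 48
        std::printf("batch %d: old ak0_start_idx %2d, fixed ak0_start_idx %2d\n",
                    kbatch_id, old_start, new_start);
    }
    return 0; // only the fixed offsets partition the full AK0 extent of 64
}
```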
@@ -777,25 +784,15 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock.value, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock.value, 0, 0);

         // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
-
-        // TODO: what if AK1 != BK1 ???
-        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(k_tiles);
-        // __builtin_amdgcn_readfirstlane((a_grid_desc_ak0_m_ak1.GetLength(I1) *
-        //                                 a_grid_desc_ak0_m_ak1.GetLength(I3)) /
-        //                                 KPerBlock);
-
-        if(blockIdx.x < 4 && ck::debug::is_thread_local_1d_id_idx<0>())
-        {
-            printf("[RunGEMM] bid: %d, num_k_block_main_loop %d\n",
-                   static_cast<index_t>(blockIdx.x),
-                   num_k_block_main_loop);
-        }
+        const auto gridwise_gemm_pipeline = GridwiseGemmPipe();
+        const index_t num_k_block_main_loop =
+            __builtin_amdgcn_readfirstlane(next_k_tiles * num_k_tiles_per_batch);
+        const bool has_k_block_main_loop =
+            gridwise_gemm_pipeline.CalculateHasMainLoop(num_k_block_main_loop);

         bool clear_c_thread_buf = true;
@@ -813,7 +810,9 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                                KPack,
                                                LoopSched>();

-        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
+        if(has_k_block_main_loop)
+        {
+            gridwise_gemm_pipeline.template Run<true>(a_grid_desc_ak0_m_ak1,
                                                       a_block_desc_ak0_m_ak1,
                                                       a_blockwise_copy,
                                                       a_grid_buf,
...
@@ -830,8 +829,28 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                                       num_k_block_main_loop,
                                                       clear_c_thread_buf);
+        }
+        else
+        {
+            gridwise_gemm_pipeline.template Run<false>(a_grid_desc_ak0_m_ak1,
+                                                       a_block_desc_ak0_m_ak1,
+                                                       a_blockwise_copy,
+                                                       a_grid_buf,
+                                                       a_block_buf,
+                                                       a_block_slice_copy_step,
+                                                       b_grid_desc_bk0_n_bk1,
+                                                       b_block_desc_bk0_n_bk1,
+                                                       b_blockwise_copy,
+                                                       b_grid_buf,
+                                                       b_block_buf,
+                                                       b_block_slice_copy_step,
+                                                       blockwise_gemm,
+                                                       c_thread_buf,
+                                                       num_k_block_main_loop,
+                                                       clear_c_thread_buf);
+        }
     }

-    template <bool HasMainKBlockLoop, typename Block2ETileMap, typename CThreadBuf>
+    template <typename Block2ETileMap, typename CThreadBuf>
     __device__ static void RunGEMM(const void* __restrict__ p_a_grid_,
                                    const void* __restrict__ p_b_grid_,
                                    void* __restrict__ p_shared,
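HasMainKBlockLoop drops out of RunGEMM's template parameters because, with the tile loop, the number of K tiles a work item processes is only known at run time. The replacement pattern is a runtime branch over the two compile-time instantiations of the pipeline's Run. A minimal sketch of that dispatch, with a stand-in pipeline type (the shape of CalculateHasMainLoop is an assumption here):

```cpp
// Stand-in for the real GridwiseGemmPipe, which carries many more parameters.
struct PipelineSketch
{
    static constexpr int PrefetchStages = 1;

    // Assumed shape, mirroring the CalculateHasMainLoop call in the diff.
    static constexpr bool CalculateHasMainLoop(int num_loop) { return num_loop > PrefetchStages; }

    template <bool HasMainLoop>
    static void Run(int num_loop)
    {
        if constexpr(HasMainLoop)
        { /* prefetch, main K loop, tail */ }
        else
        { /* tail only */ }
        (void)num_loop;
    }
};

void RunGemmDispatchSketch(int num_k_block_main_loop)
{
    // A runtime value selects between the two compile-time code paths.
    if(PipelineSketch::CalculateHasMainLoop(num_k_block_main_loop))
        PipelineSketch::Run<true>(num_k_block_main_loop);
    else
        PipelineSketch::Run<false>(num_k_block_main_loop);
}
```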
@@ -840,21 +859,21 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                                    const index_t M,
                                    const index_t N,
                                    const index_t K,
-                                   const index_t StrideA,
-                                   const index_t StrideB,
-                                   const index_t KBatch,
+                                   const index_t stride_a,
+                                   const index_t stride_b,
+                                   const index_t k_batch,
                                    const Block2ETileMap& block_2_etile_map,
                                    CThreadBuf& c_thread_buf,
-                                   const index_t k_tiles)
+                                   const index_t next_k_tiles)
     {
         const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
         const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);

         // tensor descriptors for block/thread-wise copy
-        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(M, K, StrideA, KBatch);
-        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(K, N, StrideB, KBatch);
+        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(M, K, stride_a, k_batch);
+        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(K, N, stride_b, k_batch);

-        RunGEMM<HasMainKBlockLoop>(p_a_grid,
+        RunGEMM(p_a_grid,
                 p_b_grid,
                 p_shared,
                 a_element_op,
...
@@ -863,7 +882,8 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                 b_grid_desc_bk0_n_bk1,
                 block_2_etile_map,
                 c_thread_buf,
-                k_tiles);
+                k_batch,
+                next_k_tiles);
     }

     template <typename CThreadBuf>
@@ -1098,24 +1118,24 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3)>{};
     }

     // M0 - 1
     // M1 - M elements per thread
     // N0 - 1
     // N1 - N repeats per thread
     // N2 - Vector load/store size
     __device__ static constexpr auto MakeReductionThreadDesc_M0M1_N0N1N2()
     {
         constexpr auto cluster_lengths = GetClusterLengthReduction_M0_N0N1();
-        constexpr auto N1_elems =
-            math::integer_divide_ceil(Number<NPerBlock>{}, cluster_lengths.At(I2));
-        static_assert(N1_elems % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0,
-                      "Invalid ReductionThreadDesc M0M1_N0N1N2! N1_elems have to be a multiple of "
-                      "CDEShuffleBlockTransferScalarPerVector_NPerBlock!");
         constexpr auto N2 = Number<CDEShuffleBlockTransferScalarPerVector_NPerBlock>{};
-        constexpr auto N1 = math::integer_divide_ceil(N1_elems, N2);
+        constexpr auto N1 = Number<NPerBlock>{} / (Number<cluster_lengths.At(I2)>{} * N2);
         constexpr auto M1 = math::integer_divide_ceil(Number<MPerBlock>{}, cluster_lengths.At(I0));

-        static_assert(Number<M1>{} * cluster_lengths.At(I0) >= Number<MPerBlock>{},
+        static_assert(Number<M1>{} * cluster_lengths.At(I0) == Number<MPerBlock>{},
                       "Invalid ReductionThreadDesc M0M1_N0N1N2! M1 * cluster_length[0] have to be "
                       "grater or equal to MPerBlock.");
-        static_assert(Number<N1>{} * Number<N2>{} * cluster_lengths.At(I2) >= Number<NPerBlock>{},
+        static_assert(Number<N1>{} * Number<N2>{} * cluster_lengths.At(I2) == Number<NPerBlock>{},
                       "Invalid ReductionThreadDesc M0M1_N0N1N2! N1 * N2 * cluster_length[2] have "
                       "to be grater or equal to NPerBlock.");
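Both assertions tighten from >= (cover at least the block) to == (tile it exactly), and N1 switches from a round-up division to an exact one; over-coverage is no longer tolerable because the padded descriptor that used to absorb the excess is gone. The before/after arithmetic under assumed parameters:

```cpp
// Illustrative parameters only.
constexpr int NPerBlock = 128;
constexpr int cluster_n = 16; // cluster_lengths.At(I2)
constexpr int N2        = 4;  // CDEShuffleBlockTransferScalarPerVector_NPerBlock

// Old: N1_elems = ceil(128 / 16) = 8, then N1 = ceil(8 / 4) = 2; ">=" allowed over-coverage.
// New: N1 = 128 / (16 * 4) = 2, and "==" rejects any layout that does not tile exactly.
constexpr int N1 = NPerBlock / (cluster_n * N2);
static_assert(N1 * N2 * cluster_n == NPerBlock,
              "reduce cluster must tile NPerBlock exactly (no padding available)");

int main() { return 0; }
```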
@@ -1129,6 +1149,15 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     {
         constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
         constexpr auto reduce_cluster_desc   = make_cluster_descriptor(cluster_length_reduce);

+        static_assert(ThisThreadBlock::GetNumOfThread() >= reduce_cluster_desc.GetElementSize(),
+                      "Error! ThisThreadBlock::GetNumOfThread() too small");
+
+        if(ThisThreadBlock::GetThreadId() >= reduce_cluster_desc.GetElementSize())
+        {
+            return;
+        }
+
         const auto reduce_thread_cluster_idx =
             reduce_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
         const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
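This method (and RunWrite below, which gains the same guard) now tolerates thread blocks larger than the reduction cluster: surplus threads would otherwise compute out-of-range cluster coordinates, so they exit up front. A HIP/CUDA-style sketch of the guard in isolation, with made-up sizes:

```cpp
// Illustrative stand-ins for ThisThreadBlock::GetNumOfThread() and
// reduce_cluster_desc.GetElementSize().
constexpr int block_size          = 256;
constexpr int reduce_cluster_size = 128;

// Compile-time: the thread block must be able to host the whole reduce cluster ...
static_assert(block_size >= reduce_cluster_size, "thread block too small for reduce cluster");

__device__ void ReduceSketch(int tid /* get_thread_local_1d_id() */)
{
    // ... and at run time, threads beyond the cluster take no part in the reduction.
    if(tid >= reduce_cluster_size)
    {
        return;
    }
    // threads 0 .. reduce_cluster_size-1 continue with the reduction
}
```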
@@ -1139,27 +1168,12 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         const auto workspace_grid_desc_m0m1_n0n1 =
             MakeWorkspaceGridDesc_GridSize_MPerBlock_I1_NPerBlock(get_grid_size());

-        // # of threads in NDim * vector load size * # repeats per thread
-        constexpr auto NPerBlockPadded = cluster_length_reduce.At(I2) *
-                                         workspace_thread_desc_m0m1_n0n1n2.GetLength(I3) *
-                                         workspace_thread_desc_m0m1_n0n1n2.GetLength(I4);
-        constexpr auto NPerBlockPad = NPerBlockPadded - Number<NPerBlock>{};
-
-        const auto workspace_grid_desc_m0m1_n0n1pad = transform_tensor_descriptor(
+        const auto workspace_grid_desc_m0m1_n0n1n2 = transform_tensor_descriptor(
             workspace_grid_desc_m0m1_n0n1,
             make_tuple(make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I0)),
                        make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I1)),
                        make_pass_through_transform(workspace_grid_desc_m0m1_n0n1.GetLength(I2)),
-                       make_right_pad_transform(Number<NPerBlock>{}, NPerBlockPad)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
-
-        const auto workspace_grid_desc_m0m1_n0n1n2 = transform_tensor_descriptor(
-            workspace_grid_desc_m0m1_n0n1pad,
-            make_tuple(
-                make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I0)),
-                make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I1)),
-                make_pass_through_transform(workspace_grid_desc_m0m1_n0n1pad.GetLength(I2)),
-                make_unmerge_transform(
+                       make_unmerge_transform(
                            make_tuple(workspace_thread_desc_m0m1_n0n1n2.GetLength(I3),
                                       workspace_thread_desc_m0m1_n0n1n2.GetLength(I4) *
                                           cluster_length_reduce.At(I2)))),
...
@@ -1255,7 +1269,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
     __device__ static void RunWrite(DsGridPointer p_ds_grid,
                                     EDataType* __restrict__ p_e_grid,
                                     /* void* __restrict__ p_shared, */
-                                    const AccumulationBuffer& acc_buff,
+                                    AccumulationBuffer& acc_buff,
                                     const index_t M,
                                     const index_t N,
                                     const std::array<index_t, NumDTensor> StrideDs,
@@ -1301,9 +1315,6 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             },
             Number<NumDTensor>{});

-        auto aux_vgpr_buf =
-            StaticBuffer<AddressSpaceEnum::Vgpr, EDataType, ScalarPerVector, true>{};
-
         constexpr auto d_vgpr_buf_desc = make_naive_tensor_descriptor_packed(
             make_tuple(I1, I1, I1, I1, Number<ScalarPerVector>{}));
@@ -1312,6 +1323,14 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         constexpr auto cluster_length_reduce = GetClusterLengthReduction_M0_N0N1();
         constexpr auto reduce_cluster_desc   = make_cluster_descriptor(cluster_length_reduce);

+        static_assert(ThisThreadBlock::GetNumOfThread() >= reduce_cluster_desc.GetElementSize(),
+                      "Error! ThisThreadBlock::GetNumOfThread() too small");
+
+        if(ThisThreadBlock::GetThreadId() >= reduce_cluster_desc.GetElementSize())
+        {
+            return;
+        }
+
         const auto reduce_thread_cluster_idx =
             reduce_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
         const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
@@ -1344,6 +1363,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             },
             Number<NumDTensor>{});

+        // Each thread writes consecutive M rows and strided N columns
         auto e_grid_store =
             ThreadwiseTensorSliceTransfer_v1r3<EDataType,
                                                EDataType,
...
@@ -1368,7 +1388,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         constexpr auto MIter = workspace_thread_desc_m0m1_n0n1n2.GetLength(I1);
         constexpr auto NIter = workspace_thread_desc_m0m1_n0n1n2.GetLength(I3);

-        constexpr auto n1_step = cluster_length_reduce.At(I2);
+        constexpr auto n1_step = I1;
         constexpr auto d_grid_M1_fwd_step = make_multi_index(I0, I1, I0, I0, I0);
         constexpr auto d_grid_N1_fwd_step = make_multi_index(I0, I0, I0, n1_step, I0);
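n1_step shrinks from the cluster length to I1 as a consequence of the re-dimensioned descriptor: under the old (N1 * cluster, N2) split a thread's successive N repeats were strided by the cluster length, while under the new (N1, N2 * cluster) split the N1 coordinate is the repeat index itself. A tiny illustration:

```cpp
#include <cstdio>

int main()
{
    // Illustrative layout: 16 threads along N, 2 repeats per thread.
    constexpr int cluster_n = 16, N1 = 2;
    for(int r = 0; r < N1; ++r)
    {
        const int old_n1 = r * cluster_n; // old split: repeats strided by the cluster length
        const int new_n1 = r;             // new split: repeat index is the N1 coordinate
        std::printf("repeat %d: old n1 = %2d (step %d), new n1 = %d (step 1)\n",
                    r, old_n1, cluster_n, new_n1);
    }
    return 0;
}
```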
@@ -1410,7 +1430,7 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
             //                            src_data_refs[I0],
             //                            src_data_refs[I1],
             //                            ...)
-            unpack2(cde_element_op, tie(aux_vgpr_buf(I)), src_data_refs);
+            unpack2(cde_element_op, tie(acc_buff(acc_buf_offset + I)), src_data_refs);
         });

         // if(is_thread_local_1d_id_idx<0, 1, 8, 39>())
@@ -1429,18 +1449,22 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
         // }
         e_grid_store.Run(workspace_thread_desc_m0m1_n0n1n2,
-                         make_tuple(I0, I0, I0, I0, I0),
-                         aux_vgpr_buf,
+                         make_tuple(I0, m_idx, I0, n_idx, I0),
+                         // aux_vgpr_buf,
+                         acc_buff,
                          e_grid_desc_m0m1_n0n1n2,
                          e_grid_buf);

+        if constexpr(NIter != 1)
+        {
             if constexpr(n_idx != (NIter - 1))
             {
                 static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
                     ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
                                                            d_grid_N1_fwd_step);
                 });
                 e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_N1_fwd_step);
             }
             else
             {
@@ -1448,16 +1472,19 @@ class GridwiseGemmMultipleD_xdl_splitk_cshuffle_v2
                     ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
                                                            d_grid_N1_bwd_step);
                 });
                 e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_N1_bwd_step);
             }
+        }
         }); // NIter

         {
             static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
                 ds_grid_load(d_idx).MoveSrcSliceWindow(ds_grid_desc_m0m1_n0n1n2(d_idx),
                                                        d_grid_M1_fwd_step);
             });
             e_grid_store.MoveDstSliceWindow(e_grid_desc_m0m1_n0n1n2, d_grid_M1_fwd_step);
         }
     }); // MIter
     }
 };