gaoqiong / composable_kernel_ROCM · Commits

Commit 303d4594 (unverified)
improved zeroing (#1221)

Authored Apr 02, 2024 by zjing14; committed by GitHub on Apr 02, 2024.
Parent: 5f2c89e8

Showing 5 changed files with 345 additions and 169 deletions (+345, -169).
example/15_grouped_gemm/CMakeLists.txt  +2 -2
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp  +5 -5
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp  +2 -2
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp  +150 -69
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp  +186 -91
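What the commit does: split-K grouped GEMM accumulates each k-batch slice into E with AtomicAdd, so E must be zero before the first partial result lands. This change threads a compile-time bool Zeroing through kernel_grouped_gemm_xdl_fixed_nk and the gridwise GEMM: the work-group that owns block_work_idx[I0] == 0 zero-fills its E tile, publishes completion by bumping an atomic barrier_count_finished counter, and the other split-K work-groups spin-wait on that counter before accumulating. With k_batch_ == 1 the output is written with Set instead and the entire zeroing path is compiled out. Below is a minimal host-side analogue of that handshake, using std::thread/std::atomic purely for illustration; the names mirror the kernel's, but none of this is CK API:

#include <atomic>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

int main()
{
    constexpr int kbatch    = 4; // number of split-K partial sums per tile
    constexpr int tile_size = 8;

    std::vector<std::atomic<float>> e_tile(tile_size);    // output tile, garbage until zeroed
    std::atomic<std::uint32_t> barrier_count_finished{0}; // plays the kernel's barrier counter

    auto worker = [&](int kbatch_id) {
        if(kbatch_id == 0) // the designated work-group zero-fills the tile
        {
            for(auto& e : e_tile)
                e.store(0.0f, std::memory_order_relaxed);
            barrier_count_finished.fetch_add(1, std::memory_order_release);
        }
        else // everyone else spin-waits until the zeroing has landed
        {
            while(barrier_count_finished.load(std::memory_order_acquire) == 0) {}
        }
        for(auto& e : e_tile) // the "AtomicAdd" accumulation of this partial result
        {
            float old = e.load(std::memory_order_relaxed);
            while(!e.compare_exchange_weak(old, old + 1.0f)) {}
        }
    };

    std::vector<std::thread> pool;
    for(int k = 0; k < kbatch; ++k)
        pool.emplace_back(worker, k);
    for(auto& t : pool)
        t.join();

    std::printf("e_tile[0] = %.1f (expected %d.0)\n", static_cast<double>(e_tile[0].load()), kbatch);
}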
example/15_grouped_gemm/CMakeLists.txt

@@ -23,8 +23,8 @@ add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
 add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
 add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
-add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
-add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
+add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16_fp8 grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp)
+add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16_fp8)
 if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp

@@ -36,7 +36,7 @@ using BDataType = F16;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using DsDataType       = ck::Tuple<>;
-using EDataType        = F32;
+using EDataType        = F16;

 using ALayout = Row;
 using BLayout = Col;
@@ -298,9 +298,9 @@ int main(int argc, char* argv[])
     for(int i = 0; i < problem_size.group_count; i++)
     {
-        problem_size.Ms.push_back(256 + 256 * i);
-        problem_size.Ns.push_back(256);
-        problem_size.Ks.push_back(128);
+        problem_size.Ms.push_back(128 + rand() % 128);
+        problem_size.Ns.push_back(1024);
+        problem_size.Ks.push_back(1024);
         problem_size.stride_As.push_back(problem_size.Ks[i]);
         problem_size.stride_Bs.push_back(problem_size.Ks[i]);
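The new test loop exercises the fixed-NK constraint directly: only M may vary per group (here randomized), while N and K stay uniform across groups. A standalone reduction of that loop, with a hypothetical ProblemSize struct mirroring the example's fields:

#include <cstdlib>
#include <vector>

struct ProblemSize // mirrors the fields used by the example
{
    std::vector<int> Ms, Ns, Ks, stride_As, stride_Bs;
    int group_count = 16;
};

int main()
{
    ProblemSize problem_size;
    for(int i = 0; i < problem_size.group_count; i++)
    {
        problem_size.Ms.push_back(128 + rand() % 128); // M varies per group at runtime
        problem_size.Ns.push_back(1024);               // N fixed across all groups
        problem_size.Ks.push_back(1024);               // K fixed across all groups
        problem_size.stride_As.push_back(problem_size.Ks[i]); // row-major A: leading dim = K
        problem_size.stride_Bs.push_back(problem_size.Ks[i]); // column-major B: leading dim = K
    }
    return 0;
}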
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp → example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp

@@ -35,7 +35,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using ADataType        = F16;
 using BDataType        = F8;
 using AccDataType      = F32;
-using CShuffleDataType = F32;
+using CShuffleDataType = F16;
 using DsDataType       = ck::Tuple<>;
 using EDataType        = F16;
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp

@@ -23,6 +23,7 @@ namespace device {
 template <typename GridwiseGemm,
           typename GemmDesc,
           GemmSpecialization GemmSpec,
+          bool Zeroing,
           typename ALayout,
           typename BLayout,
           typename DsLayout,
@@ -106,8 +107,37 @@ __global__ void
         const auto block_2_etile_map =
             GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);

-        auto barrier_count_finished =
-            barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
-
-        GridwiseGemm::template Run<HasMainKBlockLoop,
-                                   EGlobalMemoryDataOperation,
+        if constexpr(Zeroing)
+        {
+            auto barrier_count_finished =
+                barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
+
+            GridwiseGemm::template RunWithZeroing<HasMainKBlockLoop,
+                                                  EGlobalMemoryDataOperation,
+                                                  GemmSpec,
+                                                  ALayout,
+                                                  BLayout,
+                                                  DsLayout,
+                                                  ELayout>(gemm_desc_ptr[group_id].p_a_grid,
+                                                           gemm_desc_ptr[group_id].p_b_grid,
+                                                           p_ds_grid_,
+                                                           gemm_desc_ptr[group_id].p_e_grid,
+                                                           p_shared,
+                                                           barrier_count_finished,
+                                                           a_element_op,
+                                                           b_element_op,
+                                                           c_element_op,
+                                                           M,
+                                                           N,
+                                                           K,
+                                                           StrideA,
+                                                           StrideB,
+                                                           StrideDs,
+                                                           StrideE,
+                                                           KBatch,
+                                                           block_2_etile_map);
+        }
+        else
+        {
+            GridwiseGemm::template Run<HasMainKBlockLoop,
+                                       EGlobalMemoryDataOperation,
@@ -120,7 +150,7 @@ __global__ void
                                        p_ds_grid_,
                                        gemm_desc_ptr[group_id].p_e_grid,
                                        p_shared,
-                                       barrier_count_finished,
+                                       nullptr,
                                        a_element_op,
                                        b_element_op,
                                        c_element_op,
@@ -133,6 +163,7 @@ __global__ void
                                        StrideE,
                                        KBatch,
                                        block_2_etile_map);
+        }

         id_off += grid_size_grp;
         id_local += grid_size_grp;
@@ -193,8 +224,11 @@ template <typename ALayout,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
-          typename ComputeType = ADataType,
-          LoopScheduler LoopSched = make_default_loop_scheduler()>
+          PipelineVersion PipelineVer = PipelineVersion::v1,
+          LoopScheduler LoopSched     = make_default_loop_scheduler(),
+          typename ComputeType        = ADataType,
+          typename ALDSType           = ComputeType,
+          typename BLDSType           = ComputeType>
 struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                                                                         BLayout,
                                                                         DsLayout,
@@ -215,11 +249,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};

+    using AComputeType = ComputeType;
+    using BComputeType = ComputeType;
+
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
         BDataType,
-        ComputeType,
+        AComputeType,
+        BComputeType,
         AccDataType,
         CShuffleDataType,
         DsDataType,
@@ -258,7 +296,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         CShuffleNXdlPerWavePerShuffle,
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CDEBlockTransferScalarPerVector_NPerBlock,
-        LoopSched>;
+        LoopSched,
+        PipelineVer,
+        ALDSType,
+        BLDSType>;

     template <typename UnderlyingBlockToCTileMap>
     struct OffsettedBlockToCTileMapMLoops
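The appended template parameters are all defaulted, and ALDSType/BLDSType chain down to ComputeType, so only mixed-precision instantiations need to name them. A stripped-down sketch of that chaining (illustrative only, not the real parameter list):

#include <type_traits>

template <typename ADataType,
          typename ComputeType = ADataType,
          typename ALDSType    = ComputeType,
          typename BLDSType    = ComputeType>
struct DeviceOpSketch
{
    using ALds = ALDSType;
    using BLds = BLDSType;
};

// Callers that never named ComputeType are untouched by the new parameters:
static_assert(std::is_same<DeviceOpSketch<float>::ALds, float>::value,
              "LDS types chain down from ADataType by default");
// Mixed-precision callers override only what differs:
static_assert(std::is_same<DeviceOpSketch<float, float, float, char>::BLds, char>::value,
              "B-side LDS type can be chosen independently");

int main() { return 0; }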
@@ -613,10 +654,49 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         float ave_time = 0;

         auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
+            if(arg.k_batch_ == 1)
+            {
+                const auto kernel =
+                    kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
+                                                     GroupedGemmKernelArgument<NumDTensor>,
+                                                     GemmSpec,
+                                                     false,
+                                                     ALayout,
+                                                     BLayout,
+                                                     DsLayout,
+                                                     ELayout,
+                                                     DsDataType,
+                                                     Block2ETileMap,
+                                                     GroupedGemmBlock2ETileMap,
+                                                     AElementwiseOperation,
+                                                     BElementwiseOperation,
+                                                     CDEElementwiseOperation,
+                                                     e_global_memory_operation_,
+                                                     has_main_k_block_loop_>;
+
+                return launch_and_time_kernel(
+                    stream_config,
+                    kernel,
+                    dim3(arg.grid_size_),
+                    dim3(BlockSize),
+                    0,
+                    cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
+                    nullptr,
+                    arg.barrier_size_grp_,
+                    arg.gemm_desc_kernel_arg_.size(),
+                    arg.grid_size_grp_,
+                    arg.k_batch_,
+                    arg.a_element_op_,
+                    arg.b_element_op_,
+                    arg.c_element_op_);
+            }
+            else
+            {
                 const auto kernel =
                     kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
                                                      GroupedGemmKernelArgument<NumDTensor>,
                                                      GemmSpec,
+                                                     true,
                                                      ALayout,
                                                      BLayout,
                                                      DsLayout,
@@ -645,13 +725,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_);
+            }
         };

         constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
         constexpr auto Set       = InMemoryDataOperationEnum::Set;

-        // For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced
-        // in IsSupportedArgument function
+        // For bf16 datatype only kbatch = 1 scenario is supported. This condition is
+        // enforced in IsSupportedArgument function
         if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
         {
             if(has_main_k_block_loop)
@@ -719,12 +800,12 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         bool supported = true;

-        // If we use padding we do not support vector loads for dimensions not divisible by vector
-        // load size.
+        // If we use padding we do not support vector loads for dimensions not divisible by
+        // vector load size.
         if constexpr(GemmSpec != GemmSpecialization::Default)
         {
-            // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
-            // thus we have to adapt it to the {M,K} or {N,K} layout.
+            // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1}
+            // layout, thus we have to adapt it to the {M,K} or {N,K} layout.
             const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
             const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
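launch_kernel above picks the kernel instantiation from the runtime k_batch_: one pass writes E directly (Set, Zeroing = false), while several passes need AtomicAdd plus in-kernel zeroing (Zeroing = true). A toy sketch of that runtime-to-compile-time dispatch, with hypothetical names and no actual GPU launch:

#include <cstdio>

template <bool Zeroing>
void launch_grouped_gemm(int k_batch) // stand-in for instantiating the GPU kernel
{
    std::printf("k_batch = %d -> Zeroing = %s\n", k_batch, Zeroing ? "true" : "false");
}

void dispatch(int k_batch)
{
    if(k_batch == 1)
        launch_grouped_gemm<false>(k_batch); // Set: each E tile written exactly once
    else
        launch_grouped_gemm<true>(k_batch); // AtomicAdd: E must be zeroed in-kernel first
}

int main()
{
    dispatch(1);
    dispatch(4);
}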
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp

@@ -31,7 +31,8 @@ namespace ck {
 // D0, D1, ... and E have the same layout
 template <typename ADataType,
           typename BDataType,
-          typename ComputeType,
+          typename AComputeType,
+          typename BComputeType,
           typename AccDataType,
           typename CShuffleDataType,
           typename DsDataType,
@@ -71,7 +72,9 @@ template <typename ADataType,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched,
-          PipelineVersion PipelineVer = PipelineVersion::v1>
+          PipelineVersion PipelineVer,
+          typename ALDSType,
+          typename BLDSType>
 struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
 {
     static constexpr index_t NumDTensor = DsDataType::Size();
@@ -186,8 +189,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         constexpr auto c_block_size =
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

-        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
-                             sizeof(ComputeType),
+        return math::max(a_block_space_size_aligned * sizeof(ALDSType) +
+                             b_block_space_size_aligned * sizeof(BLDSType),
                          c_block_size * sizeof(CShuffleDataType));
     }
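The revised GetSharedMemoryNumberOfByte accounts for A and B tiles staged at different element widths. A quick worked example of the two formulas, with hypothetical block sizes and uint16_t/uint8_t standing in for fp16/fp8:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main()
{
    constexpr std::size_t a_block_space_size_aligned = 128 * 32; // ~MPerBlock * KPerBlock
    constexpr std::size_t b_block_space_size_aligned = 128 * 32; // ~NPerBlock * KPerBlock
    constexpr std::size_t c_block_size               = 128 * 128;

    using ALDSType = std::uint16_t; // fp16 stand-in
    using BLDSType = std::uint8_t;  // fp8 stand-in

    // old: one ComputeType priced for both tiles
    constexpr std::size_t ab_old =
        (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ALDSType);

    // new: per-tile element widths
    constexpr std::size_t ab_new = a_block_space_size_aligned * sizeof(ALDSType) +
                                   b_block_space_size_aligned * sizeof(BLDSType);

    std::printf("A+B LDS, old formula: %zu bytes; new formula: %zu bytes\n", ab_old, ab_new);
    std::printf("kernel LDS: %zu bytes\n",
                std::max(ab_new, c_block_size * sizeof(float))); // CShuffle tile staged as f32
}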
@@ -455,6 +458,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
               InMemoryDataOperationEnum EGlobalMemoryDataOperation,
               index_t NumDTensor_,
               typename DsDataType_,
+              bool Zeroing,
               typename AGridDesc_KBatch_AK0_M_AK1,
               typename BGridDesc_KBatch_BK0_N_BK1,
               typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -530,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             ADataType,
-            ComputeType,
+            ALDSType,
             decltype(a_grid_desc_kbatch_ak0_m_ak1),
             decltype(a_block_desc_kbatch_ak0_m_ak1),
             ABlockTransferSrcAccessOrder,
@@ -561,7 +565,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             BDataType,
-            ComputeType,
+            BLDSType,
             decltype(b_grid_desc_kbatch_bk0_n_bk1),
             decltype(b_block_desc_kbatch_bk0_n_bk1),
             BBlockTransferSrcAccessOrder,
@@ -597,12 +601,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         // sanity check
         constexpr index_t KPack =
             math::max(math::lcm(AK1, BK1),
-                      MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+                      MfmaSelector<AComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
             BlockSize,
-            ComputeType,
-            ComputeType,
+            ALDSType,
+            BLDSType,
             AccDataType,
             decltype(a_block_desc_ak0_m_ak1),
             decltype(b_block_desc_bk0_n_bk1),
@@ -611,9 +615,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             MXdlPerWave,
             NXdlPerWave,
             KPack,
-            LoopSched>();
+            LoopSched,
+            AComputeType,
+            BComputeType>();

+#if 1
+        if constexpr(Zeroing)
+        {
             if(block_work_idx[I0] == 0)
             {
                 const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
@@ -659,14 +666,14 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                     e_grid_desc_mblock_mperblock_nblock_nperblock,
                     e_grid_buf);

-                __syncthreads();
+                __builtin_amdgcn_s_barrier();

                 if(threadIdx.x == 0)
                 {
                     atomicAdd(barrier_count_finished, 1);
                 }
             }
+        }
+#endif

         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
@@ -675,10 +682,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
+            static_cast<ALDSType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
+            static_cast<BLDSType*>(p_shared) + a_block_space_size_aligned,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
@@ -710,13 +717,15 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             num_k_block_main_loop);

         // shuffle C and write out
-        {
+        if constexpr(Zeroing)
+        {
             if(threadIdx.x == 0)
             {
                 while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
             }
-            __builtin_amdgcn_s_barrier();
+            __syncthreads();
+        }

         static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                           NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
@@ -951,6 +960,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                 }
             });

+            if constexpr(Zeroing)
+            {
                 if(threadIdx.x == 0)
                 {
                     index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
@@ -962,6 +973,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                     }
                 }
+            }
         }
     }

     template <bool HasMainKBlockLoop,
               InMemoryDataOperationEnum EGlobalMemoryDataOperation,
@@ -971,7 +983,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
               typename DsLayout,
               typename ELayout,
               typename Block2ETileMap>
-    __device__ static void Run(const void* __restrict__ p_a_grid_,
+    __device__ static void RunWithZeroing(const void* __restrict__ p_a_grid_,
                                const void* __restrict__ p_b_grid_,
                                DsGridPointer p_ds_grid,
                                void* __restrict__ p_e_grid_,
@@ -1035,7 +1047,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             if(kbatch_id == KBatch - 1)
             {
-                Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType>(
+                Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType, true>(
                     p_a_grid,
                     p_b_grid,
                     p_ds_grid,
@@ -1054,7 +1066,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             }
             else
             {
-                Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>>(
+                Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>, true>(
                     p_a_grid,
                     p_b_grid,
                     p_ds_grid,
@@ -1072,6 +1084,89 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                     block_2_etile_map);
             }
         }

+    template <bool HasMainKBlockLoop,
+              InMemoryDataOperationEnum EGlobalMemoryDataOperation,
+              GemmSpecialization GemmSpec,
+              typename ALayout,
+              typename BLayout,
+              typename DsLayout,
+              typename ELayout,
+              typename Block2ETileMap>
+    __device__ static void Run(const void* __restrict__ p_a_grid_,
+                               const void* __restrict__ p_b_grid_,
+                               DsGridPointer p_ds_grid,
+                               void* __restrict__ p_e_grid_,
+                               void* __restrict__ p_shared,
+                               uint32_t*,
+                               const AElementwiseOperation& a_element_op,
+                               const BElementwiseOperation& b_element_op,
+                               const CDEElementwiseOperation& cde_element_op,
+                               const index_t M,
+                               const index_t N,
+                               const index_t K,
+                               const index_t StrideA,
+                               const index_t StrideB,
+                               const std::array<index_t, NumDTensor> StrideDs,
+                               const index_t StrideE,
+                               const index_t KBatch,
+                               const Block2ETileMap& block_2_etile_map)
+    {
+        const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
+        const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
+        const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
+
+        using DsGridDesc_M_N =
+            remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
+
+        DsGridDesc_M_N ds_grid_desc_m_n;
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
+
+            ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
+        });
+
+        const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
+
+        // tensor descriptors for block/thread-wise copy
+        const auto a_grid_desc_kbatch_ak0_m_ak1 =
+            MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, KBatch);
+
+        const auto b_grid_desc_kbatch_bk0_n_bk1 =
+            MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, KBatch);
+
+        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;
+
+        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
+
+        static_for<0, NumDTensor, 1>{}([&](auto j) {
+            ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
+                MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
+        });
+
+        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
+            MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
+
+        Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType, false>(
+            p_a_grid,
+            p_b_grid,
+            p_ds_grid,
+            p_e_grid,
+            p_shared,
+            nullptr,
+            KBatch,
+            a_element_op,
+            b_element_op,
+            cde_element_op,
+            a_grid_desc_kbatch_ak0_m_ak1,
+            b_grid_desc_kbatch_bk0_n_bk1,
+            ds_grid_desc_mblock_mperblock_nblock_nperblock,
+            e_grid_desc_mblock_mperblock_nblock_nperblock,
+            block_2_etile_map);
+    }
 };

 } // namespace ck
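The new Run overload added at the end is a thin adapter: it rebuilds the grid descriptors from raw (M, N, K, stride) arguments, ignores the unnamed uint32_t* barrier pointer, and forwards to the descriptor-based Run with Zeroing = false. A stripped-down sketch of that adapter pattern, with a hypothetical Desc type in place of the CK descriptors:

#include <cstdio>

struct Desc // hypothetical stand-in for the CK grid descriptors
{
    int M, N, stride;
};

template <bool Zeroing>
void run_impl(const Desc& desc, unsigned* barrier_count_finished)
{
    // Zeroing == false never dereferences the barrier, so nullptr is fine.
    std::printf("M=%d N=%d Zeroing=%d barrier=%p\n",
                desc.M, desc.N, int(Zeroing), static_cast<void*>(barrier_count_finished));
}

// Raw-argument adapter: build the descriptor, drop the unused barrier pointer.
void run(int M, int N, int stride, unsigned* /*ignored, kept for a uniform signature*/)
{
    run_impl<false>(Desc{M, N, stride}, nullptr);
}

int main()
{
    run(128, 1024, 1024, nullptr);
}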