gaoqiong / composable_kernel_ROCM

Commit 734df790, authored Jan 11, 2024 by root (parent f9f2cdf9)

Add MK-KN FP16 instances.
Showing 6 changed files with 167 additions and 24 deletions:

include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp (+3, -1)
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multiple_d.hpp (+15, -0)
library/src/tensor_operation_instance/gpu/grouped_gemm_multiple_d/CMakeLists.txt (+1, -0)
library/src/tensor_operation_instance/gpu/grouped_gemm_multiple_d/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp (+125, -0)
library/src/tensor_operation_instance/gpu/grouped_gemm_multiple_d/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp (+1, -1)
profiler/src/profile_grouped_gemm_multiple_d_splitk.cpp (+22, -22)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_tile_loop.hpp

@@ -956,7 +956,9 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffle
            << BBlockTransferSrcScalarPerVector << ", "
            << CShuffleMXdlPerWavePerShuffle << ", "
            << CShuffleNXdlPerWavePerShuffle << ", "
-           << getGemmSpecializationString(GemmSpec)
+           << getGemmSpecializationString(GemmSpec) << ", "
+           << PipelineVer << ", "
+           << LoopSched
            << ">";
        // clang-format on
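The string assembled in this hunk is the kernel instance's self-description; after this change it also reports the pipeline version and loop scheduler. A minimal, hedged sketch of surfacing that descriptor, assuming the usual CK BaseOperator::GetTypeString() accessor (the helper name below is hypothetical):

#include <iostream>

// print_instance_name is a hypothetical helper; DeviceOpPtr stands for any
// smart pointer to a CK device operation such as
// DeviceGroupedGemmMultipleDSplitKXdlCShuffle. With this commit the printed
// descriptor additionally ends with the pipeline version and loop scheduler.
template <typename DeviceOpPtr>
void print_instance_name(const DeviceOpPtr& op)
{
    std::cout << op->GetTypeString() << std::endl;
}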
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_multiple_d.hpp

@@ -30,6 +30,19 @@ void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_nk_mn_irregu
                                                           PassThrough,
                                                           PassThrough>>>& instances);

+void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Row,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  F16,
+                                                  F16,
+                                                  Empty_Tuple,
+                                                  F16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);
+
template <typename ALayout,
          typename BLayout,
          typename ELayout,

@@ -74,6 +87,8 @@ struct DeviceOperationInstanceFactory<
        if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                     is_same_v<ELayout, Row>)
        {
+           add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
+               op_ptrs);
        }
        else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                          is_same_v<ELayout, Row>)
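With this branch in place, the factory specialization for row-major A (MK), row-major B (KN) and row-major E (MN) now also returns the new irregular FP16 instances. A hedged sketch of how a client would pull them, assuming the usual DeviceOperationInstanceFactory<...>::GetInstances() entry point and the alias spellings below (they mirror the shorthands used in this header but are written out locally):

#include <iostream>

#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_multiple_d.hpp"

using Row         = ck::tensor_layout::gemm::RowMajor;
using F16         = ck::half_t;
using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Row-major A (MK), row-major B (KN), row-major E (MN): the layout combination
// this commit wires into the factory.
using GroupedGemmRowRow = ck::tensor_operation::device::DeviceGroupedGemm<Row,
                                                                          Row,
                                                                          Empty_Tuple,
                                                                          Row,
                                                                          F16,
                                                                          F16,
                                                                          Empty_Tuple,
                                                                          F16,
                                                                          PassThrough,
                                                                          PassThrough,
                                                                          PassThrough>;

int main()
{
    // GetInstances() collects everything registered by the add_device_..._instances()
    // functions, which now include the MK-KN irregular FP16 list (assumed entry point).
    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
        GroupedGemmRowRow>::GetInstances();

    std::cout << "registered MK-KN FP16 grouped GEMM instances: " << op_ptrs.size() << std::endl;
    return 0;
}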
library/src/tensor_operation_instance/gpu/grouped_gemm_multiple_d/CMakeLists.txt

add_instance_library(device_grouped_gemm_multiple_d_instance
    device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
+   device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
)
library/src/tensor_operation_instance/gpu/grouped_gemm_multiple_d/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instance.cpp

New file (mode 0 → 100644), +125 lines; the diff is collapsed and not shown in this view.
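Because the new file's body is collapsed, it is not reproduced here. Judging from the declaration added to grouped_gemm_multiple_d.hpp and the sibling MK-NK file below, its shape is presumably the following; the tuple alias, the add_device_operation_instances helper, and the omitted tile configurations are assumptions rather than content taken from this page:

// Hedged sketch of the collapsed MK-KN instance file, not its actual contents.
// Only the function name and the DeviceGroupedGemm signature are confirmed by
// the header change above.
void add_device_grouped_gemm_multi_d_splitk_cshuffle_f16_f16_f16_mk_kn_mn_irregular_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
                                                  Row,
                                                  Empty_Tuple,
                                                  Row,
                                                  F16,
                                                  F16,
                                                  Empty_Tuple,
                                                  F16,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances)
{
    // The tuple would hold DeviceGroupedGemmMultipleDSplitKXdlCShuffle<Row, Row, ...>
    // tile configurations analogous to the MK-NK list in the next file (alias name assumed).
    add_device_operation_instances(
        instances,
        device_ggemm_md_splitk_xdl_cshuffle_f16_f16_f16_mk_kn_mn_irregular_tile_instances{});
}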
library/src/tensor_operation_instance/gpu/grouped_gemm_multiple_d/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_instance.cpp

@@ -43,7 +43,7 @@ using device_ggemm_md_splitk_xdl_cshuffle_f16_f16_f16_mk_nk_mn_irregular_tile_in
        DeviceGroupedGemmMultipleDSplitKXdlCShuffle<Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffle<Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffle<Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffle<Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffle<Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffle<Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffle<Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffle<Row, Col, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, PipelineVersion::v1>,
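For orientation, the value parameters of the first instance above map onto the tiling configuration roughly as annotated below; the parameter names follow the conventional DeviceGroupedGemmMultipleDSplitKXdlCShuffle ordering and are inferred, not shown in this diff:

// Hedged annotation of the first MK-NK instance above (names inferred from the
// usual CK xdl_cshuffle parameter order, not from this page):
//
//   GemmMNKPadding,   // GemmSpec: pad M/N/K so irregular problem sizes are accepted
//   1,                // NumGemmKPrefetchStage
//   256,              // BlockSize (threads per workgroup)
//   128, 64, 32,      // MPerBlock, NPerBlock, KPerBlock tile
//   8, 8,             // AK1, BK1 (K-contiguous vector lengths)
//   32, 32,           // MPerXDL, NPerXDL (MFMA tile)
//   2, 1,             // MXdlPerWave, NXdlPerWave
//   S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,  // A block-transfer descriptors
//   S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1,  // B block-transfer descriptors
//   1, 1,             // CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle
//   S<1, 32, 1, 8>,   // C-shuffle block-transfer cluster lengths
//   8,                // C-shuffle scalar-per-vector
//   PipelineVersion::v1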
profiler/src/profile_grouped_gemm_multiple_d_splitk.cpp

@@ -91,28 +91,28 @@ int profile_grouped_gemm_multiple_d_splitk(int argc, char* argv[])
    const auto StrideCs = argToIntArray(argv[13]);
    const int kbatch    = argc == 15 ? std::stoi(argv[14]) : 1;

#ifdef CK_ENABLE_FP16
-   // if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
-   // {
-   //     ck::profiler::profile_ggemm_multid_splitk<ck::half_t,
-   //                                               ck::half_t,
-   //                                               ck::half_t,
-   //                                               float,
-   //                                               ck::tensor_layout::gemm::RowMajor,
-   //                                               ck::tensor_layout::gemm::RowMajor,
-   //                                               ck::tensor_layout::gemm::RowMajor>(do_verification,
-   //                                                                                  init_method,
-   //                                                                                  do_log,
-   //                                                                                  time_kernel,
-   //                                                                                  Ms,
-   //                                                                                  Ns,
-   //                                                                                  Ks,
-   //                                                                                  StrideAs,
-   //                                                                                  StrideBs,
-   //                                                                                  StrideCs,
-   //                                                                                  kbatch);
-   // }
-   // else
-   if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+   if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+   {
+       ck::profiler::profile_ggemm_multid_splitk<ck::half_t,
+                                                 ck::half_t,
+                                                 ck::half_t,
+                                                 float,
+                                                 ck::tensor_layout::gemm::RowMajor,
+                                                 ck::tensor_layout::gemm::RowMajor,
+                                                 ck::tensor_layout::gemm::RowMajor>(do_verification,
+                                                                                    init_method,
+                                                                                    do_log,
+                                                                                    time_kernel,
+                                                                                    Ms,
+                                                                                    Ns,
+                                                                                    Ks,
+                                                                                    StrideAs,
+                                                                                    StrideBs,
+                                                                                    StrideCs,
+                                                                                    kbatch);
+   }
+   else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
    {
        ck::profiler::profile_ggemm_multid_splitk<ck::half_t,
                                                  ck::half_t,