gaoqiong / composable_kernel_ROCM · Commits · bce6b139

Commit bce6b139, authored Mar 28, 2024 by Jakub Piasecki
Commit message: tmp save
Parent: 5a78bce1
Showing 5 changed files with 213 additions and 2 deletions.
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp (+25 / -0)
library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt (+1 / -0)
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_instance.cpp (+98 / -0)
profiler/src/profile_grouped_gemm_two_stage.cpp (+26 / -2)
script/profile_ggemm_splitk.sh (+63 / -0)
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp
@@ -159,6 +159,19 @@ void add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_insta
                                                   PassThrough,
                                                   PassThrough,
                                                   PassThrough>>>& instances);
 
+void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_instances(
+    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
+                                                  Row,
+                                                  Empty_Tuple,
+                                                  Row,
+                                                  BF16,
+                                                  BF16,
+                                                  Empty_Tuple,
+                                                  BF16,
+                                                  PassThrough,
+                                                  PassThrough,
+                                                  PassThrough>>>& instances);
+
 void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances(
     std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
@@ -271,6 +284,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
...
@@ -271,6 +284,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs
);
op_ptrs
);
}
}
}
}
#endif
#if defined(CK_ENABLE_BF16)
else
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
bhalf_t
>
&&
is_same_v
<
EDataType
,
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_instances
(
op_ptrs
);
}
}
#endif
#endif
return
op_ptrs
;
return
op_ptrs
;
}
}
...
...
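With the factory branch above in place, a client that instantiates `DeviceGroupedGemm` for row-major BF16/BF16/BF16 should receive the new two-stage split-K instances from `GetInstances()`. The sketch below is a minimal, illustrative lookup only; it assumes the usual composable_kernel factory interface (`DeviceOperationInstanceFactory<...>::GetInstances()` and `GetTypeString()` on the returned operations), and the extra includes are assumptions rather than part of this commit.

```cpp
// Hedged sketch: enumerate the BF16 row-major grouped GEMM instances registered by this commit.
#include <iostream>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"

int main()
{
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using Empty_Tuple = ck::Tuple<>;
    using BF16        = ck::bhalf_t;

    // Same template signature as the declaration added in the header hunk above:
    // DeviceGroupedGemm<ALayout, BLayout, DsLayout, ELayout, AData, BData, DsData, EData,
    //                   AElementwiseOp, BElementwiseOp, CDEElementwiseOp>
    using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm<Row,
                                                                     Row,
                                                                     Empty_Tuple,
                                                                     Row,
                                                                     BF16,
                                                                     BF16,
                                                                     Empty_Tuple,
                                                                     BF16,
                                                                     PassThrough,
                                                                     PassThrough,
                                                                     PassThrough>;

    // The new CK_ENABLE_BF16 branch routes this signature to
    // add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_instances.
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " bf16 grouped GEMM instances\n";
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << '\n';
    return 0;
}
```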
library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt
@@ -10,5 +10,6 @@ add_instance_library(device_grouped_gemm_instance
     device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp
     device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp
     device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp
+    device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_instance.cpp
     device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp
 )
library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_instance.cpp
New file (mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp"

#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F32  = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using Empty_Tuple = ck::Tuple<>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

// Instances having AK1!=BK1 are temporarily disabled and will be re-enabled in future
// a[m, k] * b[k, n] = e[m, n]
using device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_generic_instances =
    std::tuple<
        // clang-format off
//#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>
        // clang-format on
        >;

using device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_instances =
    std::tuple<
        // clang-format off
//#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>,
        DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage<Row, Row, Empty_Tuple, Row, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>
        // clang-format on
        >;

void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
                                                  Row,
                                                  Empty_Tuple,
                                                  Row,
                                                  BF16,
                                                  BF16,
                                                  Empty_Tuple,
                                                  BF16,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances, device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_instances{});
    add_device_operation_instances(
        instances,
        device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_generic_instances{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
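The new translation unit follows the usual CK registration pattern: a tuple of tuned tile configurations plus a single MNK-padded "generic" fallback, both appended to the caller's vector via `add_device_operation_instances`. Conceptually, that helper default-constructs each kernel description listed in the tuple and stores it behind the base interface pointer. The sketch below illustrates that mechanism under those assumptions; it is not the library's actual implementation, and the name `append_instances` is hypothetical.

```cpp
#include <memory>
#include <tuple>
#include <vector>

// Illustrative stand-in for add_device_operation_instances: every concrete kernel
// description type listed in the tuple becomes one heap-allocated instance that is
// appended to the runtime vector as a base-class pointer.
template <typename BaseOp, typename... ConcreteOps>
void append_instances(std::vector<std::unique_ptr<BaseOp>>& instances,
                      std::tuple<ConcreteOps...> /*type list only, values unused*/)
{
    // Default-construct each description and push it back; unique_ptr<Derived>
    // converts implicitly to unique_ptr<BaseOp>.
    (instances.push_back(std::make_unique<ConcreteOps>()), ...);
}
```

Seen through that lens, `add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_bf16_bf16_mk_kn_mn_instances` registers the tuned tile configurations first and the generic configuration last, so both the factory path above and the profiler below can iterate over all of them.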
profiler/src/profile_grouped_gemm_two_stage.cpp
@@ -17,7 +17,8 @@ enum struct GemmMatrixLayout
 enum struct GemmDataType
 {
     F16_F16_F16,    // 0
-    BF16_INT8_BF16  // 1
+    BF16_INT8_BF16, // 1
+    BF16_BF16_BF16  // 2
 };
 
 #define OP_NAME "grouped_gemm_two_stage"

@@ -47,7 +48,7 @@ int profile_grouped_gemm_two_stage(int argc, char* argv[])
     {
         std::cout
             << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
-            << "arg2: data type (0: fp16; 1: bf16@int8)\n"
+            << "arg2: data type (0: fp16; 1: bf16@int8; 2: bf16)\n"
             << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);\n"
             << "arg4: verification (0: no; 1: yes)\n"
             << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"

@@ -145,6 +146,29 @@ int profile_grouped_gemm_two_stage(int argc, char* argv[])
                                                                          n_warmup,
                                                                          n_iter);
     }
+    else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        ck::profiler::profile_grouped_gemm_two_stage_impl<ck::bhalf_t,
+                                                          ck::bhalf_t,
+                                                          ck::bhalf_t,
+                                                          float,
+                                                          ck::tensor_layout::gemm::RowMajor,
+                                                          ck::tensor_layout::gemm::RowMajor,
+                                                          ck::tensor_layout::gemm::RowMajor>(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            Ms,
+            Ns,
+            Ks,
+            StrideAs,
+            StrideBs,
+            StrideCs,
+            kbatch,
+            n_warmup,
+            n_iter);
+    }
     else
     {
         throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented");
script/profile_ggemm_splitk.sh
New file (mode 100644)
#!/bin/bash

## GPU visibility
export HIP_VISIBLE_DEVICES=0

DRIVER="../build/bin/ckProfiler"
OP="grouped_gemm_two_stage"
DATATYPE="2" #1: bf16int8 2:bf16
LAYOUT="0"
VERIFY="0"
INIT="0"
LOG="0"
TIME="1"

$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 2048 8192 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 2048 65536 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 8192 1024 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 8192 8192 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 8192 65536 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8 4096 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 4096 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 24 4096 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 32 4096 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 40 4096 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 48 4096 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 56 4096 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 64 4096 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 72 4096 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 80 4096 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 88 4096 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8 4096 11008 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 24 4096 11008 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 4096 11008 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 40 4096 11008 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 32 4096 11008 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 56 4096 11008 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 48 4096 11008 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 72 4096 11008 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 64 4096 11008 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 88 4096 11008 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 80 4096 11008 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 1228 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8 1228 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 32 1228 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 24 1228 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 48 1228 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 40 1228 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 64 1228 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 56 1228 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 80 1228 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 72 1228 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 8 2201 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 88 1228 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 24 2201 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 16 2201 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 40 2201 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 32 2201 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 56 2201 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 48 2201 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 72 2201 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 64 2201 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 88 2201 4096 -1 -1 -1 0
$DRIVER $OP $DATATYPE $LAYOUT $VERIFY $INIT $LOG $TIME 80 2201 4096 -1 -1 -1 0
\ No newline at end of file