gaoqiong / composable_kernel

Commit ddba342e, authored Dec 19, 2022 by Po-Yen, Chen
Parent: dc06e3fb

    Use more generalizable pattern

Showing 4 changed files with 122 additions and 102 deletions:

    include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp   +52 -39
    test/gemm/CMakeLists.txt                                              +10 -7
    test/gemm/gemm_standalone_xdl_fp16.cpp                                +19 -15
    test/gemm/instance/gemm_f16_tn_instance.cpp                           +41 -41
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp
@@ -104,66 +104,79 @@ struct GridwiseGemmPipeline_v2
#define IGLP_OPT_STRATEGY 1
#endif

#if !defined(LAYOUT_NN)
#define LAYOUT_NN 0
#define LAYOUT_NT 1
#define LAYOUT_TN 2
#define LAYOUT_TT 3
#endif

#if defined(ENABLE_PIPELINE_V2_OPT)
#if IGLP_OPT_STRATEGY == 1
                // 8 MFMAs
                __builtin_amdgcn_sched_group_barrier(0x020, 2, 0);  // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x020, 2, 0);  // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
#if GEMM_LAYOUT == LAYOUT_TN
#if IGLP_OPT_STRATEGY == 1
                __builtin_amdgcn_sched_group_barrier(0x0008, 4, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);  // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x0200, 2, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
#elif IGLP_OPT_STRATEGY == 2
                // 16 MFMAs
                __builtin_amdgcn_sched_group_barrier(0x0008, 8, 0); // MFMA
                // cluster #1
                __builtin_amdgcn_sched_group_barrier(0x020, 2, 0);  // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x020, 2, 0);  // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x0200, 2, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0200, 2, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);  // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0008, 4, 0); // MFMA
#elif IGLP_OPT_STRATEGY == 3
                __builtin_amdgcn_sched_group_barrier(0x0008, 16, 0); // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);   // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0200, 2, 0);  // DS write
                // cluster #2
                __builtin_amdgcn_sched_group_barrier(0x0020, 1, 0);  // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);   // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);   // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0200, 2, 0);  // DS write
                __builtin_amdgcn_sched_group_barrier(0x0008, 1, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);   // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0200, 2, 0);  // DS write
                __builtin_amdgcn_sched_group_barrier(0x0020, 1, 0);  // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x0008, 2, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);   // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0020, 1, 0);  // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x0008, 2, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);   // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0020, 1, 0);  // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x0008, 2, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);   // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0020, 1, 0);  // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x0008, 2, 0);  // MFMA
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);   // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0020, 1, 0);  // VMEM read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);   // MFMA
                __builtin_amdgcn_sched_group_barrier(0x0008, 6, 0);  // MFMA
#endif
#endif // GEMM_LAYOUT == 2
#endif // defined(ENABLE_PIPELINE_V2_OPT)
                ++i;
            } while(i < (num_loop - 2));
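A note added for context (not part of the commit): __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) is the clang AMDGPU scheduling-hint builtin. Each call roughly asks the backend scheduler to place size instructions of the class selected by mask at that point, so a chain of calls pins down an interleaving such as MFMA / VMEM read / DS write inside the hot loop. The masks used in this file (0x008, 0x020, 0x0200) correspond to the MFMA, VMEM-read and DS-write groups named in the inline comments. A minimal, hypothetical sketch of the same idiom outside the real pipeline:

// Hedged sketch only; "compute_tile" stands in for the MFMA/buffer_load/ds_write
// work the real pipeline issues and is not a function from this repository.
__device__ void schedule_hint_sketch()
{
    // ... issue this iteration's MFMA, buffer_load and ds_write instructions ...
    // compute_tile();

    // Request the ordering: 4 MFMAs, then 1 VMEM read, then 2 DS writes.
    // Arguments: (instruction-class mask, number of instructions, sync id).
    __builtin_amdgcn_sched_group_barrier(0x008, 4, 0);  // MFMA group
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);  // VMEM read group
    __builtin_amdgcn_sched_group_barrier(0x0200, 2, 0); // DS write group
}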
test/gemm/CMakeLists.txt
@@ -15,16 +15,19 @@ target_link_libraries(test_gemm_int8 PRIVATE utility)
target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)

add_library(gemm_standalone_xdl_fp16_instances STATIC
    # instance/gemm_f16_nn_instance.cpp
    # instance/gemm_f16_nt_instance.cpp
    instance/gemm_f16_tn_instance.cpp
    # instance/gemm_f16_tt_instance.cpp
)
set_source_files_properties(instance/gemm_f16_tn_instance.cpp PROPERTIES COMPILE_OPTIONS "--save-temps;-Wno-gnu-line-marker")
set_source_files_properties(instance/gemm_f16_tn_instance.cpp PROPERTIES COMPILE_DEFINITIONS "ENABLE_PIPELINE_V2_OPT;GEMM_LAYOUT=LAYOUT_TN;IGLP_OPT_STRATEGY=1")
# set_source_files_properties(instance/gemm_f16_nn_instance.cpp PROPERTIES COMPILE_DEFINITIONS "ENABLE_PIPELINE_V2_OPT;IGLP_OPT_STRATEGY=1")
# set_source_files_properties(instance/gemm_f16_tn_instance.cpp PROPERTIES COMPILE_DEFINITIONS "ENABLE_PIPELINE_V2_OPT;IGLP_OPT_STRATEGY=1")
# set_source_files_properties(instance/gemm_f16_nt_instance.cpp PROPERTIES COMPILE_DEFINITIONS "ENABLE_PIPELINE_V2_OPT;IGLP_OPT_STRATEGY=1")
# set_source_files_properties(instance/gemm_f16_tt_instance.cpp PROPERTIES COMPILE_DEFINITIONS "ENABLE_PIPELINE_V2_OPT;IGLP_OPT_STRATEGY=1")
add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp)
target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility)
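For orientation (an illustration added here, not part of the commit): the per-file COMPILE_DEFINITIONS above mean that only the TN instance translation unit is compiled with ENABLE_PIPELINE_V2_OPT, GEMM_LAYOUT=LAYOUT_TN and IGLP_OPT_STRATEGY=1, so the layout-guarded scheduling block in gridwise_gemm_pipeline_v2.hpp is enabled for that file alone. A self-contained sketch of that selection pattern (the function name is hypothetical; it only mirrors the guards shown earlier):

// Sketch, not library code: how a TU built with
//   -DENABLE_PIPELINE_V2_OPT -DGEMM_LAYOUT=LAYOUT_TN -DIGLP_OPT_STRATEGY=1
// selects the TN-specific path at preprocessing time.
#if !defined(LAYOUT_NN)
#define LAYOUT_NN 0
#define LAYOUT_NT 1
#define LAYOUT_TN 2
#define LAYOUT_TT 3
#endif

inline const char* selected_pipeline_schedule()
{
#if defined(ENABLE_PIPELINE_V2_OPT) && (GEMM_LAYOUT == LAYOUT_TN) && (IGLP_OPT_STRATEGY == 1)
    return "TN, IGLP strategy 1: layout-specific sched_group_barrier sequence";
#else
    return "default: no layout-specific scheduling hints";
#endif
}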
test/gemm/gemm_standalone_xdl_fp16.cpp
@@ -62,22 +62,26 @@ int main(int argc, char* argv[])
    std::vector<std::tuple<GemmParams, LayoutConfig, OpFactoryFn>> problems = {
        // clang-format off
        // 104 tiles
        // {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
        // {GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
        // {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
        // {GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
        // {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
        // {GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
        // {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
        // {GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
        // {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x256},
        // {GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
        // {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
        {GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64},
        // {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
        // {GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
        // {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
        // {GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
        // 110 tiles
        // {GemmParams{2560, 2816, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
        // {GemmParams{2560, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
        ...
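For orientation (an added illustration, not code from this commit): each problem entry couples a GEMM size with a layout configuration and a factory callback, so the test can materialize exactly the operator instances whose instance library is linked in and run them over that size. A hypothetical driver fragment over such a list might look like:

// Sketch under assumptions: BaseOperator, GemmParams, LayoutConfig and
// OpFactoryFn are the names used above; run_and_verify is hypothetical.
for(const auto& [params, layout, add_instances] : problems)
{
    std::vector<std::unique_ptr<BaseOperator>> instances;
    add_instances(instances); // e.g. add_gemm_f16_tn_128x64(instances)

    for(auto& op : instances)
    {
        // run_and_verify(*op, params, layout); // launch and compare with a host reference
    }
}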
test/gemm/instance/gemm_f16_tn_instance.cpp
@@ -20,35 +20,35 @@ namespace instance {
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// using gemm_f16_tn_256x256 = std::tuple<
// // clang-format off
// //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v2>
// // clang-format on
// >;
//
// using gemm_f16_tn_256x128 = std::tuple<
// // clang-format off
// DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v2>
// // clang-format on
// >;
//using gemm_f16_tn_128x128 = std::tuple<
// // clang-format off
// DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v2>
// // clang-format on
// >;
using gemm_f16_tn_128x64 = std::tuple<
    // clang-format off
...
@@ -60,20 +60,20 @@ using gemm_f16_tn_128x64 = std::tuple<
    // clang-format on
    >;

// void add_gemm_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
// {
//     add_device_operation_instances(instances, gemm_f16_tn_256x256{});
// }
//
// void add_gemm_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
// {
//     add_device_operation_instances(instances, gemm_f16_tn_256x128{});
// }
// void add_gemm_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
// {
//     add_device_operation_instances(instances, gemm_f16_tn_128x128{});
// }
void add_gemm_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
...
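For orientation (an added illustration, not part of the commit): each add_gemm_f16_tn_* function takes the DeviceGemm_Xdl_CShuffle configurations listed in the corresponding std::tuple and appends instances of them to a caller-owned vector of BaseOperator pointers, which is how the standalone test obtains runnable operators. A minimal sketch of that registration idiom under assumed, simplified types (MyOp, make_instances and add_my_128x64 are stand-ins, not CK names):

// Self-contained sketch of the "tuple of configs -> vector of instances" pattern.
#include <memory>
#include <tuple>
#include <vector>

struct BaseOp { virtual ~BaseOp() = default; };

template <int MPerBlock, int NPerBlock>
struct MyOp : BaseOp {};

// Append one heap-allocated instance per type in the tuple, mirroring
// what add_device_operation_instances does for DeviceGemm_Xdl_CShuffle.
template <typename... Ops>
void make_instances(std::vector<std::unique_ptr<BaseOp>>& out, std::tuple<Ops...>)
{
    (out.push_back(std::make_unique<Ops>()), ...);
}

using my_128x64_configs = std::tuple<MyOp<128, 64>>;

void add_my_128x64(std::vector<std::unique_ptr<BaseOp>>& out)
{
    make_instances(out, my_128x64_configs{}); // same shape as add_gemm_f16_tn_128x64
}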