Commit ddba342e authored by Po-Yen, Chen's avatar Po-Yen, Chen
Browse files

Use more generalizable pattern

parent dc06e3fb
......@@ -104,66 +104,79 @@ struct GridwiseGemmPipeline_v2
#define IGLP_OPT_STRATEGY 1
#endif
#if defined(ENABLE_PIPELINE_V2_OPT)
#if IGLP_OPT_STRATEGY == 1
// 8 MFMAs
__builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
#if !defined(LAYOUT_NN)
#define LAYOUT_NN 0
#define LAYOUT_NT 1
#define LAYOUT_TN 2
#define LAYOUT_TT 3
#endif
__builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
#if defined(ENABLE_PIPELINE_V2_OPT)
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
#if GEMM_LAYOUT == LAYOUT_TN
#if IGLP_OPT_STRATEGY == 1
__builtin_amdgcn_sched_group_barrier(0x0008, 4, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0200, 2, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
#elif IGLP_OPT_STRATEGY == 2
// 16 MFMAs
// cluster #1
__builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0008, 8, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 2, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0200, 2, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0200, 2, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0008, 4, 0); // MFMA
#elif IGLP_OPT_STRATEGY == 3
__builtin_amdgcn_sched_group_barrier(0x0008, 16, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// cluster #2
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0200, 2, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0200, 2, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x0008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0200, 2, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x0008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x0008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x0008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x0008, 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x0020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
#endif
__builtin_amdgcn_sched_group_barrier(0x0008, 6, 0); // MFMA
#endif
#endif // GEMM_LAYOUT == 2
#endif // defined(ENABLE_PIPELINE_V2_OPT)
++i;
} while(i < (num_loop - 2));
......
......@@ -15,16 +15,19 @@ target_link_libraries(test_gemm_int8 PRIVATE utility)
target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
add_library(gemm_standalone_xdl_fp16_instances STATIC
instance/gemm_f16_nn_instance.cpp
instance/gemm_f16_nt_instance.cpp
# instance/gemm_f16_nn_instance.cpp
# instance/gemm_f16_nt_instance.cpp
instance/gemm_f16_tn_instance.cpp
instance/gemm_f16_tt_instance.cpp
# instance/gemm_f16_tt_instance.cpp
)
set_source_files_properties(instance/gemm_f16_nn_instance.cpp PROPERTIES COMPILE_DEFINITIONS "ENABLE_PIPELINE_V2_OPT;IGLP_OPT_STRATEGY=1")
set_source_files_properties(instance/gemm_f16_tn_instance.cpp PROPERTIES COMPILE_DEFINITIONS "ENABLE_PIPELINE_V2_OPT;IGLP_OPT_STRATEGY=1")
set_source_files_properties(instance/gemm_f16_nt_instance.cpp PROPERTIES COMPILE_DEFINITIONS "ENABLE_PIPELINE_V2_OPT;IGLP_OPT_STRATEGY=1")
set_source_files_properties(instance/gemm_f16_tt_instance.cpp PROPERTIES COMPILE_DEFINITIONS "ENABLE_PIPELINE_V2_OPT;IGLP_OPT_STRATEGY=1")
set_source_files_properties(instance/gemm_f16_tn_instance.cpp PROPERTIES COMPILE_OPTIONS "--save-temps;-Wno-gnu-line-marker")
set_source_files_properties(instance/gemm_f16_tn_instance.cpp PROPERTIES COMPILE_DEFINITIONS "ENABLE_PIPELINE_V2_OPT;GEMM_LAYOUT=LAYOUT_TN;IGLP_OPT_STRATEGY=1")
# set_source_files_properties(instance/gemm_f16_nn_instance.cpp PROPERTIES COMPILE_DEFINITIONS "ENABLE_PIPELINE_V2_OPT;IGLP_OPT_STRATEGY=1")
# set_source_files_properties(instance/gemm_f16_tn_instance.cpp PROPERTIES COMPILE_DEFINITIONS "ENABLE_PIPELINE_V2_OPT;IGLP_OPT_STRATEGY=1")
# set_source_files_properties(instance/gemm_f16_nt_instance.cpp PROPERTIES COMPILE_DEFINITIONS "ENABLE_PIPELINE_V2_OPT;IGLP_OPT_STRATEGY=1")
# set_source_files_properties(instance/gemm_f16_tt_instance.cpp PROPERTIES COMPILE_DEFINITIONS "ENABLE_PIPELINE_V2_OPT;IGLP_OPT_STRATEGY=1")
add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp)
target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility)
......
......@@ -62,22 +62,26 @@ int main(int argc, char* argv[])
std::vector<std::tuple<GemmParams, LayoutConfig, OpFactoryFn>> problems = {
// clang-format off
// 104 tiles
{GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
// {GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
// {GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
// {GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
// {GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x256},
// {GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
// {GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
// {GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// 110 tiles
// {GemmParams{2560, 2816, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
// {GemmParams{2560, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
......
......@@ -20,35 +20,35 @@ namespace instance {
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using gemm_f16_tn_256x256 = std::tuple<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v2>
// clang-format on
>;
// using gemm_f16_tn_256x256 = std::tuple<
// // clang-format off
// //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v2>
// // clang-format on
// >;
//
// using gemm_f16_tn_256x128 = std::tuple<
// // clang-format off
// //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v2>
// // clang-format on
// >;
using gemm_f16_tn_256x128 = std::tuple<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v2>
// clang-format on
>;
using gemm_f16_tn_128x128 = std::tuple<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v2>
// clang-format on
>;
//using gemm_f16_tn_128x128 = std::tuple<
// // clang-format off
// //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, ck::LoopScheduler::Default, ck::PipelineVersion::v2>
// // clang-format on
// >;
using gemm_f16_tn_128x64 = std::tuple<
// clang-format off
......@@ -60,20 +60,20 @@ using gemm_f16_tn_128x64 = std::tuple<
// clang-format on
>;
void add_gemm_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_tn_256x256{});
}
// void add_gemm_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
// {
// add_device_operation_instances(instances, gemm_f16_tn_256x256{});
// }
//
// void add_gemm_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
// {
// add_device_operation_instances(instances, gemm_f16_tn_256x128{});
// }
void add_gemm_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_tn_256x128{});
}
void add_gemm_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_tn_128x128{});
}
// void add_gemm_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
// {
// add_device_operation_instances(instances, gemm_f16_tn_128x128{});
// }
void add_gemm_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment