Commit 68474822 authored by Adam Osewski

Minimize the number of kernels to run.

Prepare the GEMM standalone test for profiling with Omniperf.
Run only the regular GEMM and the wavelet GEMM at their best-performing configuration.
parent a80dce16
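For context, below is a minimal, self-contained sketch of the registration-and-dispatch pattern this commit trims down. `GemmParams`, `LayoutConfig`, `BaseOperator`, and the factory function are simplified stand-ins modelled on the names visible in the diff; the real definitions live in the CK test sources and are not reproduced here.

```cpp
// Hedged sketch: simplified stand-ins for the standalone test's problem-list pattern.
// The struct fields and factory signature below are assumptions based on the names
// in this diff, not the real composable_kernel definitions.
#include <cstdio>
#include <functional>
#include <memory>
#include <tuple>
#include <vector>

struct GemmParams { int M, N, K; };                      // problem shape (assumed fields)
struct LayoutConfig { bool transA, transB, rowMajorC; }; // layout flags (assumed fields)
struct BaseOperator { virtual ~BaseOperator() = default; };

// Stand-in for the DeviceGemm_Xdl_CShuffle specialization kept by this commit.
struct Gemm_f16_tn_256x128 : BaseOperator {};

// Same shape as add_gemm_f16_tn_256x128 in gemm_f16_tn_instance.cpp.
void add_gemm_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    instances.push_back(std::make_unique<Gemm_f16_tn_256x128>());
}

using OpFactoryFn = std::function<void(std::vector<std::unique_ptr<BaseOperator>>&)>;

int main()
{
    // Only the best-performing TN 256x128 configuration stays active, matching the
    // trimmed problem list in this commit, so a profiler sees exactly one kernel.
    std::vector<std::tuple<GemmParams, LayoutConfig, OpFactoryFn>> problems = {
        {GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
    };

    for(auto& [params, layout, factory] : problems)
    {
        std::vector<std::unique_ptr<BaseOperator>> instances;
        factory(instances); // populate kernel instances for this problem
        std::printf("M=%d N=%d K=%d -> %zu instance(s) to run\n",
                    params.M, params.N, params.K, instances.size());
        (void)layout; // layout flags would select A/B/C strides in the real test
    }
    return 0;
}
```

With the problem list reduced to the single best-performing 256x128 TN entry (plus its wavelet counterpart), a profiler such as Omniperf attributes its counters to exactly the kernels of interest.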
@@ -15,11 +15,11 @@ target_link_libraries(test_gemm_int8 PRIVATE utility)
target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
add_library(gemm_standalone_xdl_fp16_instances STATIC
-instance/gemm_f16_nn_instance.cpp
-instance/gemm_f16_nt_instance.cpp
+# instance/gemm_f16_nn_instance.cpp
+# instance/gemm_f16_nt_instance.cpp
instance/gemm_f16_tn_instance.cpp
instance/gemm_wavelet_f16_tn_instance.cpp
-instance/gemm_f16_tt_instance.cpp
+# instance/gemm_f16_tt_instance.cpp
)
add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp)
target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility)
...
@@ -6,10 +6,10 @@
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
-#include "gemm_f16_nn_instance.hpp"
-#include "gemm_f16_nt_instance.hpp"
+// #include "gemm_f16_nn_instance.hpp"
+// #include "gemm_f16_nt_instance.hpp"
#include "gemm_f16_tn_instance.hpp"
-#include "gemm_f16_tt_instance.hpp"
+// #include "gemm_f16_tt_instance.hpp"
#include "gemm_wavelet_f16_tn_instance.hpp"
using Row = ck::tensor_layout::gemm::RowMajor;
@@ -62,44 +62,48 @@ int main(int argc, char* argv[])
std::vector<std::tuple<GemmParams, LayoutConfig, OpFactoryFn>> problems = {
// clang-format off
+// Use following if you run it on MI200 GPU
// 104 tiles
-{GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
-{GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
-{GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
-{GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
-{GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
-{GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
-{GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
-{GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
-{GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x256},
+// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
+// {GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
+// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
+// {GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
+// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
+// {GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
+// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
+// {GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
+// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
-{GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
-{GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64},
-{GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x256},
+// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
+// {GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64},
+// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x128},
-{GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x128},
-{GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64},
-{GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
-{GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
-{GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
-{GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
+// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x128},
+// {GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64},
+// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
+// {GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
+// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
+// {GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// 110 tiles
-{GemmParams{2560, 2816, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
-{GemmParams{2560, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
-{GemmParams{1280, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
-{GemmParams{1280, 704, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
-{GemmParams{2560, 2816, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
-{GemmParams{2560, 1408, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
-{GemmParams{1280, 1408, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
-{GemmParams{1280, 704, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
-{GemmParams{2560, 2816, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x256},
-{GemmParams{2560, 1408, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
-{GemmParams{1280, 1408, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
-{GemmParams{1280, 704, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64},
-{GemmParams{2560, 2816, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
-{GemmParams{2560, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
-{GemmParams{1280, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
-{GemmParams{1280, 704, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
+// {GemmParams{2560, 2816, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
+// {GemmParams{2560, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
+// {GemmParams{1280, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
+// {GemmParams{1280, 704, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
+// {GemmParams{2560, 2816, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
+// {GemmParams{2560, 1408, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
+// {GemmParams{1280, 1408, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
+// {GemmParams{1280, 704, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
+// {GemmParams{2560, 2816, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x256},
+// {GemmParams{2560, 1408, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
+// {GemmParams{1280, 1408, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
+// {GemmParams{1280, 704, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64},
+// {GemmParams{2560, 2816, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
+// {GemmParams{2560, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
+// {GemmParams{1280, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
+// {GemmParams{1280, 704, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// clang-format on
};
...
@@ -20,15 +20,15 @@ namespace instance {
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using gemm_f16_tn_256x256 = std::tuple< // using gemm_f16_tn_256x256 = std::tuple<
// clang-format off // // clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> // DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on // // clang-format on
>; // >;
using gemm_f16_tn_256x128 = std::tuple<
// clang-format off
@@ -40,45 +40,45 @@ using gemm_f16_tn_256x128 = std::tuple<
// clang-format on
>;
using gemm_f16_tn_128x128 = std::tuple< // using gemm_f16_tn_128x128 = std::tuple<
// clang-format off // // clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> // DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on // // clang-format on
>; // >;
using gemm_f16_tn_128x64 = std::tuple< // using gemm_f16_tn_128x64 = std::tuple<
// clang-format off // // clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> // DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on // // clang-format on
>; // >;
-void add_gemm_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
-{
-add_device_operation_instances(instances, gemm_f16_tn_256x256{});
-}
+// void add_gemm_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
+// {
+// add_device_operation_instances(instances, gemm_f16_tn_256x256{});
+// }
void add_gemm_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_tn_256x128{});
}
-void add_gemm_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
-{
-add_device_operation_instances(instances, gemm_f16_tn_128x128{});
-}
-void add_gemm_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
-{
-add_device_operation_instances(instances, gemm_f16_tn_128x64{});
-}
+// void add_gemm_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
+// {
+// add_device_operation_instances(instances, gemm_f16_tn_128x128{});
+// }
+// void add_gemm_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
+// {
+// add_device_operation_instances(instances, gemm_f16_tn_128x64{});
+// }
} // namespace instance
} // namespace device
...
@@ -27,13 +27,13 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-void add_gemm_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances);
+// void add_gemm_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
-void add_gemm_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
-void add_gemm_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances);
+// void add_gemm_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
+// void add_gemm_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances);
} // namespace instance
} // namespace device
...
@@ -30,15 +30,15 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using gemm_f16_tn_256x256 = std::tuple< // using gemm_f16_tn_256x256 = std::tuple<
// clang-format off // // clang-format off
//##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // //##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // //##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // //##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // //##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> // DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on // // clang-format on
>; // >;
using gemm_f16_tn_256x128 = std::tuple<
// clang-format off
@@ -50,45 +50,45 @@ using gemm_f16_tn_256x128 = std::tuple<
// clang-format on
>;
using gemm_f16_tn_128x128 = std::tuple< // using gemm_f16_tn_128x128 = std::tuple<
// clang-format off // // clang-format off
//##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // //##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // //##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // //##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // //##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> // DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on // // clang-format on
>; // >;
using gemm_f16_tn_128x64 = std::tuple< // using gemm_f16_tn_128x64 = std::tuple<
// clang-format off // // clang-format off
//##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // //##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // //##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| // //##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // //##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> // DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on // // clang-format on
>; // >;
-void add_gemm_wavelet_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
-{
-add_device_operation_instances(instances, gemm_f16_tn_256x256{});
-}
+// void add_gemm_wavelet_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
+// {
+// add_device_operation_instances(instances, gemm_f16_tn_256x256{});
+// }
void add_gemm_wavelet_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_tn_256x128{});
}
-void add_gemm_wavelet_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
-{
-add_device_operation_instances(instances, gemm_f16_tn_128x128{});
-}
-void add_gemm_wavelet_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
-{
-add_device_operation_instances(instances, gemm_f16_tn_128x64{});
-}
+// void add_gemm_wavelet_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
+// {
+// add_device_operation_instances(instances, gemm_f16_tn_128x128{});
+// }
+// void add_gemm_wavelet_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
+// {
+// add_device_operation_instances(instances, gemm_f16_tn_128x64{});
+// }
} // namespace instance
} // namespace device
...
@@ -11,13 +11,13 @@ namespace tensor_operation {
namespace device {
namespace instance {
-void add_gemm_wavelet_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances);
+// void add_gemm_wavelet_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_wavelet_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
-void add_gemm_wavelet_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
-void add_gemm_wavelet_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances);
+// void add_gemm_wavelet_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
+// void add_gemm_wavelet_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances);
} // namespace instance
} // namespace device
...