Commit 68474822 authored by Adam Osewski's avatar Adam Osewski
Browse files

Minimize number of kernels to run.

Prepare GEMM standalone test for profiling with Omniperf.
Just run regular GEMM and wavelet GEMM at best perfomant configuration.
parent a80dce16
......@@ -15,11 +15,11 @@ target_link_libraries(test_gemm_int8 PRIVATE utility)
target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
add_library(gemm_standalone_xdl_fp16_instances STATIC
instance/gemm_f16_nn_instance.cpp
instance/gemm_f16_nt_instance.cpp
# instance/gemm_f16_nn_instance.cpp
# instance/gemm_f16_nt_instance.cpp
instance/gemm_f16_tn_instance.cpp
instance/gemm_wavelet_f16_tn_instance.cpp
instance/gemm_f16_tt_instance.cpp
# instance/gemm_f16_tt_instance.cpp
)
add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp)
target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility)
......
......@@ -6,10 +6,10 @@
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#include "gemm_f16_nn_instance.hpp"
#include "gemm_f16_nt_instance.hpp"
// #include "gemm_f16_nn_instance.hpp"
// #include "gemm_f16_nt_instance.hpp"
#include "gemm_f16_tn_instance.hpp"
#include "gemm_f16_tt_instance.hpp"
// #include "gemm_f16_tt_instance.hpp"
#include "gemm_wavelet_f16_tn_instance.hpp"
using Row = ck::tensor_layout::gemm::RowMajor;
......@@ -62,44 +62,48 @@ int main(int argc, char* argv[])
std::vector<std::tuple<GemmParams, LayoutConfig, OpFactoryFn>> problems = {
// clang-format off
// Use following if you run it on MI200 GPU
// 104 tiles
{GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x256},
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
// {GemmParams{2048, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
// {GemmParams{1024, 832, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
// {GemmParams{2048, 3328, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
// {GemmParams{2048, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
// {GemmParams{1024, 832, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x256},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
// {GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64},
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64},
{GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
{GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
{GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
{GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x128},
// {GemmParams{1024, 832, 4096}, LayoutConfig{true, false, true}, add_gemm_wavelet_f16_tn_128x64},
// {GemmParams{2048, 3328, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
// {GemmParams{2048, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
// {GemmParams{1024, 1664, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
// {GemmParams{1024, 832, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// 110 tiles
{GemmParams{2560, 2816, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
{GemmParams{2560, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
{GemmParams{1280, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
{GemmParams{1280, 704, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
{GemmParams{2560, 2816, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
{GemmParams{2560, 1408, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
{GemmParams{1280, 1408, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
{GemmParams{1280, 704, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
{GemmParams{2560, 2816, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x256},
{GemmParams{2560, 1408, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
{GemmParams{1280, 1408, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
{GemmParams{1280, 704, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64},
{GemmParams{2560, 2816, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
{GemmParams{2560, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
{GemmParams{1280, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
{GemmParams{1280, 704, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// {GemmParams{2560, 2816, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x256},
// {GemmParams{2560, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_256x128},
// {GemmParams{1280, 1408, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x128},
// {GemmParams{1280, 704, 4096}, LayoutConfig{false, false, true}, add_gemm_f16_nn_128x64},
// {GemmParams{2560, 2816, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x256},
// {GemmParams{2560, 1408, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_256x128},
// {GemmParams{1280, 1408, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x128},
// {GemmParams{1280, 704, 4096}, LayoutConfig{false, true, true}, add_gemm_f16_nt_128x64},
// {GemmParams{2560, 2816, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x256},
// {GemmParams{2560, 1408, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_256x128},
// {GemmParams{1280, 1408, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x128},
// {GemmParams{1280, 704, 4096}, LayoutConfig{true, false, true}, add_gemm_f16_tn_128x64},
// {GemmParams{2560, 2816, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x256},
// {GemmParams{2560, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_256x128},
// {GemmParams{1280, 1408, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x128},
// {GemmParams{1280, 704, 4096}, LayoutConfig{true, true, true}, add_gemm_f16_tt_128x64},
// clang-format on
};
......
......@@ -20,15 +20,15 @@ namespace instance {
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using gemm_f16_tn_256x256 = std::tuple<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
// using gemm_f16_tn_256x256 = std::tuple<
// // clang-format off
// //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// // clang-format on
// >;
using gemm_f16_tn_256x128 = std::tuple<
// clang-format off
......@@ -40,45 +40,45 @@ using gemm_f16_tn_256x128 = std::tuple<
// clang-format on
>;
using gemm_f16_tn_128x128 = std::tuple<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
// using gemm_f16_tn_128x128 = std::tuple<
// // clang-format off
// //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// // clang-format on
// >;
using gemm_f16_tn_128x64 = std::tuple<
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
// using gemm_f16_tn_128x64 = std::tuple<
// // clang-format off
// //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// // clang-format on
// >;
void add_gemm_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_tn_256x256{});
}
// void add_gemm_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
// {
// add_device_operation_instances(instances, gemm_f16_tn_256x256{});
// }
void add_gemm_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_tn_256x128{});
}
void add_gemm_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_tn_128x128{});
}
// void add_gemm_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
// {
// add_device_operation_instances(instances, gemm_f16_tn_128x128{});
// }
void add_gemm_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_tn_128x64{});
}
// void add_gemm_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
// {
// add_device_operation_instances(instances, gemm_f16_tn_128x64{});
// }
} // namespace instance
} // namespace device
......
......@@ -27,13 +27,13 @@ using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
void add_gemm_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances);
// void add_gemm_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
// void add_gemm_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances);
// void add_gemm_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances);
} // namespace instance
} // namespace device
......
......@@ -30,15 +30,15 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using gemm_f16_tn_256x256 = std::tuple<
// clang-format off
//##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
// using gemm_f16_tn_256x256 = std::tuple<
// // clang-format off
// //##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// // clang-format on
// >;
using gemm_f16_tn_256x128 = std::tuple<
// clang-format off
......@@ -50,45 +50,45 @@ using gemm_f16_tn_256x128 = std::tuple<
// clang-format on
>;
using gemm_f16_tn_128x128 = std::tuple<
// clang-format off
//##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
using gemm_f16_tn_128x64 = std::tuple<
// clang-format off
//##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_gemm_wavelet_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_tn_256x256{});
}
// using gemm_f16_tn_128x128 = std::tuple<
// // clang-format off
// //##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// // clang-format on
// >;
// using gemm_f16_tn_128x64 = std::tuple<
// // clang-format off
// //##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// // clang-format on
// >;
// void add_gemm_wavelet_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
// {
// add_device_operation_instances(instances, gemm_f16_tn_256x256{});
// }
void add_gemm_wavelet_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_tn_256x128{});
}
void add_gemm_wavelet_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_tn_128x128{});
}
// void add_gemm_wavelet_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
// {
// add_device_operation_instances(instances, gemm_f16_tn_128x128{});
// }
void add_gemm_wavelet_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
add_device_operation_instances(instances, gemm_f16_tn_128x64{});
}
// void add_gemm_wavelet_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
// {
// add_device_operation_instances(instances, gemm_f16_tn_128x64{});
// }
} // namespace instance
} // namespace device
......
......@@ -11,13 +11,13 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_gemm_wavelet_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances);
// void add_gemm_wavelet_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_wavelet_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_wavelet_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
// void add_gemm_wavelet_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances);
void add_gemm_wavelet_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances);
// void add_gemm_wavelet_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances);
} // namespace instance
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment