Commit 3037f6c7 authored by Adam Osewski's avatar Adam Osewski
Browse files

Compile only 256x128 tile size wavelet gemm.

parent 3002a1a5
...@@ -30,15 +30,15 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -30,15 +30,15 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// using gemm_f16_tn_256x256 = std::tuple< using gemm_f16_tn_256x256 = std::tuple<
// // clang-format off // clang-format off
// //##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// // clang-format on // clang-format on
// >; >;
using gemm_f16_tn_256x128 = std::tuple< using gemm_f16_tn_256x128 = std::tuple<
// clang-format off // clang-format off
...@@ -50,45 +50,45 @@ using gemm_f16_tn_256x128 = std::tuple< ...@@ -50,45 +50,45 @@ using gemm_f16_tn_256x128 = std::tuple<
// clang-format on // clang-format on
>; >;
// using gemm_f16_tn_128x128 = std::tuple< using gemm_f16_tn_128x128 = std::tuple<
// // clang-format off // clang-format off
// //##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// // clang-format on // clang-format on
// >; >;
// using gemm_f16_tn_128x64 = std::tuple< using gemm_f16_tn_128x64 = std::tuple<
// // clang-format off // clang-format off
// //##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //##################### | ALayout| BLayout| CLayout| AData| BData| AccData| CShuffle| CData| A| B| C| GEMM| NumGemmK| ABBlockTransfer| BlockGemm| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// //##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //##################### | | | | Type| Type| Type| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| ThreadGroupSize| ThreadGroupSize| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// //##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //##################### | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// //##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##################### | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> DeviceGemm_Xdl_WaveletModel_CShuffle< Row, Col, Row, F16, F16, F32, F16, F16, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// // clang-format on // clang-format on
// >; >;
// void add_gemm_wavelet_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances) void add_gemm_wavelet_f16_tn_256x256(std::vector<std::unique_ptr<BaseOperator>>& instances)
// { {
// add_device_operation_instances(instances, gemm_f16_tn_256x256{}); add_device_operation_instances(instances, gemm_f16_tn_256x256{});
// } }
void add_gemm_wavelet_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances) void add_gemm_wavelet_f16_tn_256x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
{ {
add_device_operation_instances(instances, gemm_f16_tn_256x128{}); add_device_operation_instances(instances, gemm_f16_tn_256x128{});
} }
// void add_gemm_wavelet_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances) void add_gemm_wavelet_f16_tn_128x128(std::vector<std::unique_ptr<BaseOperator>>& instances)
// { {
// add_device_operation_instances(instances, gemm_f16_tn_128x128{}); add_device_operation_instances(instances, gemm_f16_tn_128x128{});
// } }
// void add_gemm_wavelet_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances) void add_gemm_wavelet_f16_tn_128x64(std::vector<std::unique_ptr<BaseOperator>>& instances)
// { {
// add_device_operation_instances(instances, gemm_f16_tn_128x64{}); add_device_operation_instances(instances, gemm_f16_tn_128x64{});
// } }
} // namespace instance } // namespace instance
} // namespace device } // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment