From 7dcb14d9d495e6329477f5ee27a27cfbb4ce49a6 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 5 Sep 2023 11:29:44 -0700 Subject: [PATCH 001/112] fix syntax (#890) --- .../gpu/device/impl/device_image_to_column_impl.hpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp index 89fcbca1a..19f126e66 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp @@ -151,8 +151,9 @@ struct DeviceImageToColumnImpl remove_cvref_t; using OutputGridDesc = remove_cvref_t; - using Block2ETileMap = remove_cvref_t(OutputGridDesc{}))>; + using Block2ETileMap = remove_cvref_t< + decltype(BlockToCTileMap_M00_N0_M01Adapt( + OutputGridDesc{}))>; using GridwiseImageToColumnKernel = GridwiseImageToColumn Date: Tue, 5 Sep 2023 21:25:28 +0200 Subject: [PATCH 002/112] Add contribution guidelines to the documentation (#843) Add contribution guidelines to the documentation --- docs/Contributors_Guide.rst | 100 ++++++++++++++++++++++++++++++++++-- 1 file changed, 97 insertions(+), 3 deletions(-) diff --git a/docs/Contributors_Guide.rst b/docs/Contributors_Guide.rst index b2ddff398..41cb8f191 100644 --- a/docs/Contributors_Guide.rst +++ b/docs/Contributors_Guide.rst @@ -2,7 +2,101 @@ Contributor's Guide =================== -Pull-request guidelines -======================= +This chapter explains how to get started contributing to the Composable Kernel project and what are +the contributing rules. -[TODO] +Getting started +=============== + +#. **Documentation:** Before contributing to the library, familiarize yourself with the + `Composable Kernel User Guide `_. + It provides insight into the core concepts, environment configuration, and steps to obtain or + build the library. You can also find some of this information in the + `README file `_ + on the project's GitHub page. +#. **Additional reading:** We also recommend reading a `blog post + `_ + from the AMD Community portal. It offers a deeper understanding of the library's objectives and + showcases its performance capabilities. +#. **General information:** For broader information about AMD products, consider exploring the + `AMD Developer Central portal `_. + +How do I contribute +=================== + +We deeply value contributions from our users. You can make an impact by reporting issues or +proposing code enhancements through pull requests. + +Reporting issues +---------------- + +We use `Github issues `_ +to track public bugs and enhancement requests. + +If you encounter an issue with the library, please check if the problem has already been +reported by searching existing issues on GitHub. If your issue seems unique, please submit a new +issue. All reported issues must include: + +* A comprehensive description of the problem, including: + + * What did you observe? + * Why do you think it is a bug (if it seems like one)? + * What did you expect to happen? What would indicate the resolution of the problem? + * Are there any known workarounds? + +* Your configuration details, including: + + * Which GPU are you using? + * Which OS version are you on? + * Which ROCm version are you using? + * Are you using a Docker image? If so, which one? + +* Steps to reproduce the issue, including: + + * What actions trigger the issue? What are the reproduction steps? + + * If you build the library from scratch, what CMake command did you use? + + * How frequently does this issue happen? Does it reproduce every time? Or is it a sporadic issue? + +Before sumbitting any issue, ensure you have addressed all relevant questions from the checklist. + +Creating Pull Requests +---------------------- + +You can submit `Pull Requests (PR) on GitHub +`_. + +All contributors are required to develop their changes on a separate branch and then create a +pull requrest to merge their changes into the `develop` branch, which is the default +development branch in the Composable Kernel project. All external contributors must use their own +forks of the project to develop their changes. + +When submitting a Pull Request you should: + +* Describe the change providing information about the motivation for the change and a general + description of all code modifications. + +* Verify and test the change: + + * Run any relevant existing tests. + * Write new tests if added functionality is not covered by current tests. + +* Ensure your changes align with the coding style defined in the ``.clang-format`` file located in + the project's root directory. We leverage `pre-commit` to run `clang-format` automatically. We + highly recommend contributors utilize this method to maintain consistent code formatting. + Instructions on setting up `pre-commit` can be found in the project's + `README file `_ + +* Link your PR to any related issues: + + * If there is an issue that is resolved by your change, please provide a link to the issue in + the description of your pull request. + +* For larger contributions, structure your change into a sequence of smaller, focused commits, each + addressing a particular aspect or fix. + +Following the above guidelines ensures a seamless review process and faster assistance from our +end. + +Thank you for your commitment to enhancing the Composable Kernel project! We look forward to collaborating with you. -- GitLab From aae4df559615e611bafd5850be5168287d97edd9 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 5 Sep 2023 14:29:37 -0700 Subject: [PATCH 003/112] set warnings as errors in doxygen (#864) --- cmake/DoxygenDoc.cmake | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmake/DoxygenDoc.cmake b/cmake/DoxygenDoc.cmake index 2e3669fcd..c91308b5b 100644 --- a/cmake/DoxygenDoc.cmake +++ b/cmake/DoxygenDoc.cmake @@ -309,6 +309,8 @@ XML_OUTPUT XML_PROGRAMLISTING ) +set(WARN_AS_ERROR YES) + set(DOXYGEN_CONFIG_FILE "${CMAKE_CURRENT_BINARY_DIR}/doxygen/doxygen.conf" CACHE PATH "Path to generated doxygen configuration file") function(add_doxygen_doc) -- GitLab From a61b8b785eae5f4432fbf62e85b3b119cf019ea0 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Wed, 6 Sep 2023 09:59:20 -0500 Subject: [PATCH 004/112] Fixed fp8 gemm (#882) * add generic instances; fixed initi with fp8 * fixed comment --------- Co-authored-by: Jing Zhang --- .../library/utility/host_tensor_generator.hpp | 35 ++++++++++++++++ ...f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp | 16 ++++++++ ...f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp | 18 ++++++++- ...dl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp | 12 ++++++ ...dl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp | 40 ++++++++++++------- 5 files changed, 106 insertions(+), 15 deletions(-) diff --git a/library/include/ck/library/utility/host_tensor_generator.hpp b/library/include/ck/library/utility/host_tensor_generator.hpp index 31ff13aec..2fdb0b141 100644 --- a/library/include/ck/library/utility/host_tensor_generator.hpp +++ b/library/include/ck/library/utility/host_tensor_generator.hpp @@ -95,6 +95,22 @@ struct GeneratorTensor_2 } }; +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f8_t operator()(Is...) + { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + return ck::type_convert(tmp); + } +}; +#endif + template struct GeneratorTensor_3 { @@ -127,6 +143,25 @@ struct GeneratorTensor_3 } }; +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f8_t operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp = min_value + tmp * (max_value - min_value); + + return ck::type_convert(fp32_tmp); + } +}; +#endif + template struct GeneratorTensor_4 { diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp index 7946bafc2..10c445346 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp @@ -33,6 +33,18 @@ using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +using device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_generic_instances = + std::tuple< + // clang-format off + // M/N/K padding + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Row, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + using device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances = std::tuple< // clang-format off @@ -73,6 +85,10 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_m PassThrough, MultiplyAdd>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_generic_instances{}); + add_device_operation_instances( instances, device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp index 99db9f427..71a3e4249 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp @@ -33,6 +33,19 @@ using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +using device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_generic_instances = + std::tuple< + // clang-format off + // M/N/K padding + // N % 8 == 0 && K % 1 == 0 + //##############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 2, 1, 32>, 1> + // clang-format on + >; + using device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances = std::tuple< // clang-format off @@ -55,7 +68,6 @@ using device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_i DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 64>, 1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1>, DeviceGemmMultipleD_Xdl_CShuffle< Row, Col, Row_Tuple, Row, F16, F8, F32, F32, F32_Tuple, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 2, 1, 32>, 1> - // clang-format on >; @@ -72,6 +84,10 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_m PassThrough, MultiplyAdd>>>& instances) { + add_device_operation_instances( + instances, + device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_generic_instances{}); + add_device_operation_instances( instances, device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp index c0b6ec67b..6168695b7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp @@ -30,6 +30,16 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; // static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_generic_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, true, 1, 1, S<1, 16, 1, 8>, 2, F16> + // clang-format on + >; + // Compilation parameters for a[m, k] * b[k, n] = c[m, n] using device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances = std::tuple< // clang-format off @@ -61,6 +71,8 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances( DeviceGemmSplitK>>& instances) { + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_generic_instances{}); add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instances{}); } diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp index fa38640ae..a2dd9532a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp @@ -27,7 +27,17 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_generic_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 2, F16> + // clang-format on + >; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple< @@ -36,19 +46,19 @@ using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances = std::tuple< //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type| //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| | //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, - DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16> + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNKPadding, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16> // clang-format on >; @@ -57,6 +67,8 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( DeviceGemmSplitK>>& instances) { + add_device_operation_instances(instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_generic_instances{}); add_device_operation_instances(instances, device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances{}); } -- GitLab From 3786bfe1cce4b895494261a19a31e41f6b43fc7a Mon Sep 17 00:00:00 2001 From: zjing14 Date: Wed, 6 Sep 2023 10:15:52 -0500 Subject: [PATCH 005/112] added padding of K into gemm_v2r3 (#887) * added kpad support into v2r3 * add generic instances * fixed comments * fixed mnk padding * Update device_batched_gemm_xdl.hpp --------- Co-authored-by: Jing Zhang --- .../device/impl/device_batched_gemm_xdl.hpp | 14 +++--- .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 45 +++++++++++++++++-- ...m_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp | 12 +++++ ...m_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp | 12 +++++ 4 files changed, 72 insertions(+), 11 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp index 408c56d49..303eba156 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp @@ -185,7 +185,7 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm" << " NumGemmKPrefetchStage: " << NumGemmKPrefetchStage << ", " diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 31aee70db..b6c146ae6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -194,7 +194,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 StrideC{StrideC_}, MPadded{CalculateMPadded(M_)}, NPadded{CalculateNPadded(N_)}, - K0{CalculateK0(K)} + K0{CalculateK0(K_)} { } @@ -383,7 +383,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const index_t num_loop = K / (K0PerBlock * K1); + const index_t num_loop = math::integer_divide_ceil(K, K0PerBlock * K1); return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } @@ -840,7 +840,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext } }(); - if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock; + const auto KPad = K0Pad * K1Value; + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) { return transform_tensor_descriptor( a_grid_desc_m_k, @@ -874,7 +892,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext } }(); - if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock; + const auto KPad = K0Pad * K1Value; + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) { return transform_tensor_descriptor( b_grid_desc_k_n, diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp index 2f42f62b0..dc7de8c68 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp @@ -25,6 +25,16 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_generic_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + // Compilation parameters for a[m, k] * b[k, n] = c[m, n] using device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances = std::tuple< // clang-format off @@ -100,6 +110,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances( DeviceBatchedGemm>>& instances) { + add_device_operation_instances( + instances, device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_generic_instances{}); add_device_operation_instances(instances, device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances{}); } diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp index 10b4cea7d..cccad7ca1 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp @@ -25,6 +25,16 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_generic_instances = std::tuple< + // clang-format off + //#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumGemmK| LoopScheduler| Pipeline| + //#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| Prefetch| | | + //#################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| Stage | | | + //#################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + // Compilation parameters for a[m, k] * b[n, k] = c[m, n] using device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances = std::tuple< // clang-format off @@ -88,6 +98,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances( DeviceBatchedGemm>>& instances) { + add_device_operation_instances( + instances, device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_generic_instances{}); add_device_operation_instances(instances, device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances{}); } -- GitLab From 37a8c1f756f94e95b996fc3c26ff39fcf4e435c5 Mon Sep 17 00:00:00 2001 From: Bartlomiej Wroblewski Date: Wed, 6 Sep 2023 18:44:09 +0200 Subject: [PATCH 006/112] Redesign the DPP8 GEMM kernel to use warp-wise component (#863) * Redesign the DPP8 GEMM kernel to use warp-wise component * Review: Improve error messages * Review: Remove unnecessary empty lines * Review: Fix M, N per thread names * Review: Rename mfma_input_type to dpp_input_type * Review: Fix tensor adaptor; remove unnecessary element * Review: Remove calls to dpp_gemm's MakeCDescriptor * Review: Add blockwise doc, change function names to include dimension names * Review: Remove duplicated code; Move Block2CtileMap alias to the top of the file * Review: Add __restrict__ keywords * Review: Use MatrixPadder for padding A, B, C matrices * Review: Remove hardcoded datatypes * Review: Change names from FloatX to XDataType * Review: Introduce AK0 and BK0 instead of a single K0 * Review: Remove construction of dpp_datatypes object * Review: Rename DppInstrRunner to DppLanegroupGemm --- example/01_gemm/CMakeLists.txt | 3 +- example/01_gemm/gemm_dl_dpp8_fp16.cpp | 37 - example/01_gemm/gemm_dpp_fp16.cpp | 39 + .../gpu/block/blockwise_gemm_dl_dpp8.hpp | 370 --------- .../gpu/block/blockwise_gemm_dpp.hpp | 348 +++++++++ .../gpu/block/blockwise_gemm_xdlops.hpp | 16 +- .../gpu/device/gemm_dl_algorithm.hpp | 18 - .../gpu/device/impl/device_gemm_dl.hpp | 27 +- .../gpu/device/impl/device_gemm_dl_dpp8.hpp | 133 ---- .../gpu/device/impl/device_gemm_dpp.hpp | 271 +++++++ .../gpu/grid/gridwise_gemm_dl_v1r3.hpp | 71 +- .../gpu/grid/gridwise_gemm_dpp.hpp | 701 ++++++++++++++++++ .../gpu/grid/gridwise_gemm_pipeline_v1.hpp | 3 +- .../thread/threadwise_contraction_dl_dpp8.hpp | 136 ---- .../ck/tensor_operation/gpu/warp/dpp_gemm.hpp | 322 ++++++++ include/ck/utility/amd_gemm_dpp.hpp | 56 +- include/ck/utility/inner_product_dpp8.hpp | 4 + include/ck/utility/loop_scheduler.hpp | 26 + .../tensor_operation_instance/gpu/gemm.hpp | 16 +- .../gpu/gemm/CMakeLists.txt | 8 +- ..._dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp | 61 -- ..._dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp | 61 -- ..._dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp | 60 -- ..._dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp | 61 -- ...gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp | 58 ++ ...gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp | 58 ++ ...gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp | 57 ++ ...gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp | 58 ++ 28 files changed, 2030 insertions(+), 1049 deletions(-) delete mode 100644 example/01_gemm/gemm_dl_dpp8_fp16.cpp create mode 100644 example/01_gemm/gemm_dpp_fp16.cpp delete mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp delete mode 100644 include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp delete mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp delete mode 100644 include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp create mode 100644 include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp create mode 100644 include/ck/utility/loop_scheduler.hpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index a671877f6..3dc2a0966 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -6,8 +6,7 @@ if(DL_KERNELS) if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) add_dependencies(example_gemm_dl example_gemm_dl_fp16) - add_example_executable(example_gemm_dl_dpp8_fp16 gemm_dl_dpp8_fp16.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_dpp8_fp16) + add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp) endif() if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) diff --git a/example/01_gemm/gemm_dl_dpp8_fp16.cpp b/example/01_gemm/gemm_dl_dpp8_fp16.cpp deleted file mode 100644 index ea0ba3907..000000000 --- a/example/01_gemm/gemm_dl_dpp8_fp16.cpp +++ /dev/null @@ -1,37 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "common.hpp" - -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" - -using ADataType = ck::half_t; -using BDataType = ck::half_t; -using CDataType = ck::half_t; -using AccDataType = float; - -using ALayout = Col; -using BLayout = Row; -using CLayout = Row; - -using AElementOp = PassThrough; -using BElementOp = PassThrough; -using CElementOp = PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDlDpp8 -// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| -// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| -// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>; -// clang-format on - -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - -#include "run_gemm_example.inc" - -int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/01_gemm/gemm_dpp_fp16.cpp b/example/01_gemm/gemm_dpp_fp16.cpp new file mode 100644 index 000000000..7a9e3f618 --- /dev/null +++ b/example/01_gemm/gemm_dpp_fp16.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CDataType = ck::half_t; + +using F16 = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDpp +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 64, 64, 64, 8, 2, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 5, 1>; +// // clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp deleted file mode 100644 index e527509f5..000000000 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp +++ /dev/null @@ -1,370 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/amd_gemm_dpp.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_adaptor.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp" - -namespace ck { - -/** - * DPP8 version of blockwise GEMM algorithm. It uses DPP8 instruction modifier to limit - * the data loaded from LDS to registers. - * - * The algorithm groups threads into groups of size `dpp8::lane_group_size` and splits the matrix C - * between them in such a way that threads from the same group need the same chunk of either - * matrix A (or B, respectively). Without the usage of DPP8, each thread would need to load the - * whole chunk from LDS to its own register space. - * Usage of DPP8 modifiers allow each thread to load less data, exactly `1 / dpp8::lane_group_size` - * of the chunk, and then share that data with other threads from the same lane group. - * - * Assumptions coming from the usage of DPP8: - * 1. `BM10BN10ThreadClusterBM10Xs[1] == dpp8::lane_group_size` or - * `BM10BN10ThreadClusterBN10Xs[1] == dpp8::lane_group_size` - - * - it makes consecutive `dpp8::lane_group_size` threads use the same chunk of either - * matrix A or B; - * - based on these values we determine which matrix to share. - * 2. `BM1PerThreadBM11 % dpp8::lane_group_size == 0` (if sharing A) or - * `BN1PerThreadBN11 % dpp8::lane_group_size == 0` (if sharing B) - - * - we have to make sure that the data to split is divisible by the number of - * threads in the group. - * - * General algorithm: - * C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1] - * A and B are visible to the whole block, C is distributed among each thread - * Assume: - * 1. A: - * 1. ABlockDesc_BK0_BM_BK1 is known at compile-time - * 2. ABlockBuffer is DynamicBuffer - * 2. B: - * 1. BBlockDesc_BK0_BN_BK1 is known at compile-time - * 2. BBlockBuffer is DynamicBuffer - * 3. C: - * 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time - * 2. CThreadBuffer is StaticBuffer - * 4. BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2 - */ -template - typename BM10BN10ThreadClusterBN10Xs, // Sequence - index_t AThreadCopyScalarPerVector_BM11, - index_t BThreadCopyScalarPerVector_BN11, - typename enable_if::type = false> -struct BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0 -{ - using AIndex = MultiIndex<4>; - using BIndex = MultiIndex<4>; - using CIndex = MultiIndex<4>; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0); - static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2); - static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1); - static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1); - - static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0]; - static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0]; - - static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1]; - static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1]; - - static constexpr index_t BM11 = BM1PerThreadBM11; - static constexpr index_t BN11 = BN1PerThreadBN11; - - static constexpr index_t BM1 = BM100 * BM101 * BM11; - static constexpr index_t BN1 = BN100 * BN101 * BN11; - - static constexpr index_t BM0 = BM / BM1; - static constexpr index_t BN0 = BN / BN1; - - // We assume that either `BM101` or `BN101` is equal to `dpp8::lane_group_size`. It makes all - // threads in a lane group need the same chunk of B or A matrices and we can share them using - // DPP. - static_assert(BM101 == dpp8::lane_group_size || BN101 == dpp8::lane_group_size); - static constexpr bool ShareB = BM101 == dpp8::lane_group_size ? true : false; - static constexpr bool ShareA = !ShareB; - - // If DPP shares A (B, respectively), lane group gets `BM1PerThreadBM11` (`BN1PerThreadBN11`, - // respectively) elements, so we split them between threads in lane group so each thread loads - // less data from LDS. - static constexpr index_t BM1PerThread = - ShareA ? BM1PerThreadBM11 / dpp8::lane_group_size : BM1PerThreadBM11; - static constexpr index_t BN1PerThread = - ShareB ? BN1PerThreadBN11 / dpp8::lane_group_size : BN1PerThreadBN11; - - __host__ __device__ static constexpr auto - MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1) - { - const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor( - a_block_desc_bk0_bm_bk1, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - return a_block_bk0_bm0_bm1_bk1; - } - - __host__ __device__ static constexpr auto - MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1) - { - const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor( - b_block_desc_bk0_bn_bk1, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - return b_block_desc_bk0_bn0_bn1_bk1; - } - - __host__ __device__ static constexpr auto - MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN() - { - // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - // lower: [BM, BN] - constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n = - make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}, Number{})), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); - - return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n; - } - - __host__ __device__ static constexpr auto - MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1() - { - // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - // lower: [BM0, BM1, BN0, BN1] - constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 = - make_single_stage_tensor_adaptor( - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); - - return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1; - } - - __host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1() - { - return Sequence{}; - } - - static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ = - MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{}); - - static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ = - MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); - - public: - __device__ BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0() - : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( - get_thread_local_1d_id())}, - a_thread_copy_{CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1()}, - b_thread_copy_{CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1()} - { - static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && - BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!"); - - static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) == - BBlockDesc_BK0_BN_BK1{}.GetLength(I0), - "wrong! K dimension not consistent"); - - static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 && - BM10BN10ThreadClusterBN10Xs::Size() == 2, - "wrong!"); - } - - __device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id) - { - // lower: [BM0, BM1, BN0, BN1] - // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - constexpr auto adaptor0 = - MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1(); - - // lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] - // upper: [Tid, BM0, BM11, BN0, BN11] - constexpr auto adaptor1 = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)), - make_pass_through_transform(BM0), - make_pass_through_transform(BM11), - make_pass_through_transform(BN0), - make_pass_through_transform(BN11)), - make_tuple( - Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - - constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); - - return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); - } - - __device__ AIndex CalculateAThreadOriginOnBlock_BK0_BM0_BM1_BK1() - { - const auto offsetBM0 = c_thread_origin_data_idx_[I0]; - // If sharing matrix A, we need a separate BM1 offset for each thread in lane group. - const auto offsetBM1 = ShareA ? c_thread_origin_data_idx_[I1] + - dpp8::get_thread_idx_in_lane_group() * BM1PerThread - : c_thread_origin_data_idx_[I1]; - return make_tuple(0, offsetBM0, offsetBM1, 0); - } - - __device__ BIndex CalculateBThreadOriginOnBlock_BK0_BN0_BN1_BK1() - { - const auto offsetBN0 = c_thread_origin_data_idx_[I2]; - // If sharing matrix B, we need a separate BN1 offset for each thread in lane group. - const auto offsetBN1 = ShareB ? c_thread_origin_data_idx_[I3] + - dpp8::get_thread_idx_in_lane_group() * BN1PerThread - : c_thread_origin_data_idx_[I3]; - return make_tuple(0, offsetBN0, offsetBN1, 0); - } - - template - __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&, - const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - auto a_thread_buf = make_static_buffer( - a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); - - constexpr auto threadwise_contraction = - ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< - FloatA, - FloatB, - FloatC, - decltype(a_thread_desc_bk0_bm0_bm1_bk1_), - decltype(b_thread_desc_bk0_bn0_bn1_bk1_), - CThreadDesc_BM0_BM11_BN0_BN11, - Sequence, - Sequence<1, BM1PerThreadBM11>, - Sequence<1, BN1PerThreadBN11>, - ShareA>{}; - - static_for<0, BN0, 1>{}([&](auto bn0) { - static_for<0, BM0, 1>{}([&](auto bm0) { - a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, - make_tuple(I0, bm0, I0, I0), - a_block_buf, - a_thread_desc_bk0_bm0_bm1_bk1_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, - make_tuple(I0, bn0, I0, I0), - b_block_buf, - b_thread_desc_bk0_bn0_bn1_bk1_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - threadwise_contraction.Run(a_thread_buf, - make_tuple(I0, I0, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - make_tuple(bm0, I0, bn0, I0)); - - static_for{}([&](auto bk0) { - a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, - make_tuple(bk0, bm0, I0, I0), - a_block_buf, - a_thread_desc_bk0_bm0_bm1_bk1_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, - make_tuple(bk0, bn0, I0, I0), - b_block_buf, - b_thread_desc_bk0_bn0_bn1_bk1_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - threadwise_contraction.Run(a_thread_buf, - make_tuple(I0, I0, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - make_tuple(bm0, I0, bn0, I0)); - }); - }); - }); - } - - private: - // A[BK0, BM0, BM1, BK1] - static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{}, Number{})); - - // B[BK0, BN0, BN1, BK1] - static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{}, Number{})); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< - FloatA, - FloatA, - decltype(a_block_desc_bk0_bm0_bm1_bk1_), - decltype(a_thread_desc_bk0_bm0_bm1_bk1_), - Sequence, // SliceLengths - Sequence<0, 1, 2, 3>, // DimAccessOrder - Sequence<1, 1, BM1PerThread, BK1>, // SrcVectorTensorLengths - Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< - FloatB, - FloatB, - decltype(b_block_desc_bk0_bn0_bn1_bk1_), - decltype(b_thread_desc_bk0_bn0_bn1_bk1_), - Sequence, // SliceLengths - Sequence<0, 1, 2, 3>, // DimAccessOrder - Sequence<1, 1, BN1PerThread, BK1>, // SrcVectorTensorLengths - Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder - - CIndex c_thread_origin_data_idx_; - - AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; -}; - -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp new file mode 100644 index 000000000..d62ed4b15 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp @@ -0,0 +1,348 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/dpp_gemm.hpp" + +namespace ck { + +/** + * Blockwise GEMM that uses DPP instruction modifier to limit the amount of data loaded for each + * thread by sharing the data between threads in a lanegroup. + * + * In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`, there are + * `MRepeat` iterations for `M` dimension and `NRepeat` for `N` one. + * In total, the algorithm runs using + * `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves. + */ +template +struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto dpp_gemm = DppGemm{}; + + static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex_M0_M1_M2_K() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex_K_M(); + const auto dpp_a_idx_k = dpp_a_idx[I0]; + const auto dpp_a_idx_m = dpp_a_idx[I1]; + return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k); + } + + __device__ static auto CalculateBThreadOriginDataIndex_N0_N1_N2_K() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex_K_N(); + const auto dpp_b_idx_k = dpp_b_idx[I0]; + const auto dpp_b_idx_n = dpp_b_idx[I1]; + return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk(); + const auto blk_m_offset = blk_idx[I0]; + const auto blk_n_offset = blk_idx[I1]; + + constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_NPerDpp_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerDpp))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_m_offset))[I0]; + const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_n_offset))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "Wrong! Block descriptors should be known at the time of compilation."); + +#if defined(__HIP_DEVICE_COMPILE__) + // Host wave size can be different than the device one and this assert could fail for host, + // but it does matter only for device. + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); +#endif + + static_assert(MPerBlock % (MPerDpp * MRepeat) == 0, + "Invalid parameters. MPerBlock must be divisible by MPerDpp * MRepeat."); + static_assert(NPerBlock % (NPerDpp * NRepeat) == 0, + "Invalid parameters. NPerBlock must be divisible by NPerDpp * NRepeat."); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths(); + constexpr auto M = c_m_n_tblk_lens[I0]; + constexpr auto N = c_m_n_tblk_lens[I1]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths(); + constexpr auto M = c_m_n_tblk_lens[I0]; + constexpr auto N = c_m_n_tblk_lens[I1]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return c_block_desc_m0_n0_m1_n1_m2_n2; + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + return c_block_desc_g_m0_n0_m1_n1_m2_n2; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return c_grid_desc_m0_n0_m1_n1_m2_n2; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return c_grid_desc_g_m0_n0_m1_n1_m2_n2; + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + static_for<0, KPerThread, KPack>{}([&](auto k) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using dpp_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + dpp_gemm.template Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegDpp] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, dpp_gemm.GetRegSizePerDpp())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex_M0_M1_M2_K()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex_N0_N1_N2_K()}; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index d5a64d7aa..1fee9c322 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -4,27 +4,13 @@ #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" namespace ck { -enum struct LoopScheduler -{ - Default, - Interwave, -}; - -constexpr LoopScheduler make_default_loop_scheduler() -{ -#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING - return LoopScheduler::Interwave; -#else - return LoopScheduler::Default; -#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING -} - template __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) diff --git a/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp b/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp deleted file mode 100644 index 16ca582b8..000000000 --- a/include/ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp +++ /dev/null @@ -1,18 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -namespace ck { -namespace tensor_operation { -namespace device { - -enum struct GemmDlAlgorithm -{ - Default, // Uses DOT vector instructions - Dpp8, // Uses DOT vector instructions with DPP8 SEL modifier to reduce data loads from LDS -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index eedf384cd..d9f8324bd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -11,7 +11,6 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm.hpp" -#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" #include "ck/host_utility/device_prop.hpp" @@ -60,7 +59,6 @@ template < typename CThreadTransferSrcDstAccessOrder, index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferDstScalarPerVector, - GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default, enable_if_t< is_same_v && is_same_v && @@ -238,8 +236,7 @@ struct DeviceGemmDl : public DeviceGemm; + CThreadTransferDstScalarPerVector>; using AGridDesc_K0_M0_M1_K1 = decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); @@ -375,8 +372,7 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, true, - true, - GemmDlAlg>; + true>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -402,8 +398,7 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, true, - false, - GemmDlAlg>; + false>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -429,8 +424,7 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, false, - true, - GemmDlAlg>; + true>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -456,8 +450,7 @@ struct DeviceGemmDl : public DeviceGemm, remove_reference_t, false, - false, - GemmDlAlg>; + false>; ave_time = launch_and_time_kernel(stream_config, kernel, @@ -492,16 +485,6 @@ struct DeviceGemmDl : public DeviceGemm -#include - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp" -#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -template < - typename ADataType, - typename BDataType, - typename CDataType, - typename AccDataType, - typename ALayout, - typename BLayout, - typename CLayout, - typename AElementwiseOperation, - typename BElementwiseOperation, - typename CElementwiseOperation, - GemmSpecialization GemmSpec, - index_t BlockSize, - index_t MPerBlock, - index_t NPerBlock, - index_t K0PerBlock, - index_t K1, - index_t M1PerThread, - index_t N1PerThread, - index_t KPerThread, - typename M1N1ThreadClusterM1Xs, - typename M1N1ThreadClusterN1Xs, - typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, - typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, - typename ABlockTransferThreadClusterArrangeOrder, - typename ABlockTransferSrcAccessOrder, - typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, - typename ABlockTransferSrcVectorTensorContiguousDimOrder, - typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, - typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1, - typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1, - typename BBlockTransferThreadClusterArrangeOrder, - typename BBlockTransferSrcAccessOrder, - typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, - typename BBlockTransferSrcVectorTensorContiguousDimOrder, - typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, - typename CThreadTransferSrcDstAccessOrder, - index_t CThreadTransferSrcDstVectorDim, - index_t CThreadTransferDstScalarPerVector, - enable_if_t< - is_same_v && - is_same_v && - is_same_v, - bool> = false> -struct DeviceGemmDlDpp8 : public DeviceGemmDl - -{ - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceGemmDlDpp8" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << K0PerBlock << ", " - << K1 << ", " - << M1PerThread << ", " - << N1PerThread << ", " - << KPerThread - << ">"; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp new file mode 100644 index 000000000..162f0690b --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmDpp : public DeviceGemm +{ + using GridwiseGemm = GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp< + BlockSize, + ADataType, + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + MPerBlock, + NPerBlock, + KPerBlock, + MPerDpp, + NPerDpp, + AK1, + BK1, + MDppPerWave, + NDppPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 1, 3, 5>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch, + PipelineVer>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_dpp has invalid setting"); + } + + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) + { + const auto kernel = kernel_gemm_dpp; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + else + { + const auto kernel = kernel_gemm_dpp; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(ck::get_device_name() == "gfx1030") + { + return GridwiseGemm::CheckValidity(karg); + } + return false; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmDpp" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerDpp << ", " + << NPerDpp << ", " + << MDppPerWave << ", " + << MDppPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 + << ">" + << " NumPrefetch: " + << NumPrefetch << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 1d1bb6ed2..1da723697 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -7,11 +7,9 @@ #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" -#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp" #include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" @@ -19,8 +17,6 @@ namespace ck { -using GemmDlAlgorithm = tensor_operation::device::GemmDlAlgorithm; - template + bool HasDoubleTailKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -43,13 +38,6 @@ __global__ void const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map) { -// DPP8 is currently only supported on gfx1030 -#if !defined(__gfx1030__) - if(GemmDlAlg == GemmDlAlgorithm::Dpp8) - { - return; - } -#endif constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -100,8 +88,7 @@ template + index_t CThreadTransferDstScalarPerVector> struct GridwiseGemmDl_km_kn_mn_v1r3 { static constexpr auto I0 = Number<0>{}; @@ -257,45 +244,6 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 c_grid_desc_m_n); } - template - __host__ __device__ static constexpr auto GetBlockwiseGemm() - { - if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8) - { - return BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0< - BlockSize, - FloatAB, - FloatAB, - FloatAcc, - ABlockDesc_BK0_BM_BK1, - BBlockDesc_BK0_BN_BK1, - M1PerThreadM111, - N1PerThreadN111, - KPerThread, - M11N11ThreadClusterM110Xs, - M11N11ThreadClusterN110Xs, - M1PerThreadM111, - N1PerThreadN111>{}; - } - else - { - return BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< - BlockSize, - FloatAB, - FloatAB, - FloatAcc, - ABlockDesc_BK0_BM_BK1, - BBlockDesc_BK0_BN_BK1, - M1PerThreadM111, - N1PerThreadN111, - KPerThread, - M11N11ThreadClusterM110Xs, - M11N11ThreadClusterN110Xs, - M1PerThreadM111, - N1PerThreadN111>{}; - } - } - using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); using CGridDesc_M0_M10_M11_N0_N10_N11 = @@ -424,7 +372,20 @@ struct GridwiseGemmDl_km_kn_mn_v1r3 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register const auto blockwise_gemm = - GetBlockwiseGemm(); + BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, + M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp new file mode 100644 index 000000000..d39f0ba30 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp @@ -0,0 +1,701 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif +#if CK_USE_WAVES_PER_EU + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) +#endif + kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane( + GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA)); + const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane( + GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane( + GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_m_n); +#else + ignore = karg; +#endif +} + +template +struct GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + static constexpr auto max_lds_align = math::lcm(AK1, BK1); + + using ThisThreadBlock = ThisThreadBlock; + // return block_id to C matrix tile idx (m0, n0) mapping + using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) { return math::integer_divide_floor(K, AK1Value); } + __host__ static auto CalculateBK0(index_t K) { return math::integer_divide_floor(K, BK1Value); } + + // Argument + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + AK0{CalculateAK0(K)}, + BK0{CalculateBK0(K)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t AK0; + index_t BK0; + }; + + // Argument + struct Argument : public Problem, public tensor_operation::device::BaseArgument + { + __host__ Argument(const ABDataType* p_a_grid_, + const ABDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const ABDataType* p_a_grid; + const ABDataType* p_b_grid; + CDataType* p_c_grid; + }; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, AK1), max_lds_align); + } + }(); + + return a_block_desc_ak0_m_ak1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BK1), max_lds_align); + } + }(); + + return b_block_desc_bk0_n_bk1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType); + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "Wrong! AK1 must be known at the time of compilation."); + static_assert(is_known_at_compile_time>::value, + "Wrong! BK1 must be known at the time of compilation."); + + static_assert( + MPerBlock % (MPerDpp * MDppPerWave) == 0, + "Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave."); + static_assert( + NPerBlock % (NPerDpp * NDppPerWave) == 0, + "Invalid tuning parameters! NPerBlock must be divisible by NPerDpp * NDppPerWave."); + + static_assert( + KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1."); + + static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0, + "Invalid tuning parameters! AK1Value must be divisible by " + "ABlockTransferDstScalarPerVector_K1"); + + static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0, + "Invalid tuning parameters! BK1Value must be divisible by " + "BBlockTransferDstScalarPerVector_K1"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if(problem.K % KPerBlock != 0) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = problem.K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const auto num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n) + { + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), DppSelector::selected_dpp.k_per_dpp); + + using BlockwiseGemm = + BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n); + } + + static constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + __device__ static auto + MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + __device__ static auto + MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_pass_through_transform(N), + make_unmerge_transform(make_tuple(BK0, BK1Value))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + } + + __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + template + __device__ static void Run(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + const auto block_2_ctile_map = + Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)}; + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0), + c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[AK0PerBlock, MPerBlock] is in LDS + // b_mtx[BK0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), DppSelector::selected_dpp.k_per_dpp); + auto blockwise_gemm = + BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0); + // (AK0 / AK0PerBlock) is always equal to (BK0 / BK0PerBlock) + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2(); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_n2, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index d1209636d..754a3e89c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -4,7 +4,8 @@ #pragma once #include "ck/utility/common_header.hpp" -#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp deleted file mode 100644 index d0d214381..000000000 --- a/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp +++ /dev/null @@ -1,136 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/utility/amd_gemm_dpp.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/utility/inner_product_dpp8.hpp" -#include "ck/utility/math.hpp" - -namespace ck { - -/** - * Threadwise contraction using dot instructions with DPP8 modifier. - * - * Assumptions: - * 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1` - * are known at compile-time; - * 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time; - * 3. `TM0` is equal to 1 and `TN0` is equal to 1; - * 4. When `ShareA` is set (unset, respectively), `TM1` (`TN1`, respectively) is divisible by - * the size of the lane group (`dpp8::lane_group_size`). - */ -template ::type = false> -struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 -{ - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - - static constexpr index_t TK0 = TKLengths{}[I0]; - static constexpr index_t TK1 = TKLengths{}[I1]; - static constexpr index_t TM0 = TMLengths{}[I0]; - static constexpr index_t TM1 = TMLengths{}[I1]; - static constexpr index_t TN0 = TNLengths{}[I0]; - static constexpr index_t TN1 = TNLengths{}[I1]; - - static_assert(TM0 == 1 && TN0 == 1); - - static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) || - (!ShareA && TN1 % dpp8::lane_group_size == 0)); - static constexpr index_t shared_elems_per_lane = - ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size; - - __device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() - { - static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && - BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && - CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2, - "wrong!"); - } - - template - __device__ static void Run(const ABuffer& a_buf, - AOriginIdx, - const BBuffer& b_buf, - BOriginIdx, - CBuffer& c_buf, - COriginIdx) - { - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - - static_assert( - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value && - is_same, remove_cvref_t>::value && - "wrong! inconsistent type"); - - constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); - constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); - constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); - - static_for<0, TK0, 1>{}([&](auto tk0) { - static_for<0, TM1, 1>{}([&](auto tm1) { - static_for<0, TN1, 1>{}([&](auto tn1) { - vector_type a_vec; - vector_type b_vec; - - static_for<0, TK1, 1>{}([&](auto tk1) { - constexpr index_t local_tm1 = ShareA ? tm1 % shared_elems_per_lane : tm1; - constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( - a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1)); - - constexpr index_t local_tn1 = ShareA ? tn1 : tn1 % shared_elems_per_lane; - constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( - b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1)); - - a_vec.template AsType()(tk1) = a_buf[Number{}]; - b_vec.template AsType()(tk1) = b_buf[Number{}]; - }); - - using a_vector_t = typename vector_type::type; - using b_vector_t = typename vector_type::type; - - constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( - c_origin_idx + make_multi_index(0, tm1, 0, tn1)); - - constexpr int src_lane = - ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size - : (tn1 / shared_elems_per_lane) % dpp8::lane_group_size; - - dpp8::inner_product_dpp( - a_vec.template AsType()[I0], - b_vec.template AsType()[I0], - c_buf(Number{})); - }); - }); - }); - } -}; - -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp new file mode 100644 index 000000000..03a4d17c9 --- /dev/null +++ b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp @@ -0,0 +1,322 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_gemm_dpp.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +enum struct DppInstr +{ + dpp8_f16_16x16x2 = 0, + dpp8_f16_8x32x2, + dpp8_f16_32x8x2 +}; + +/** + * Structure representing DPP GEMM executed by a single wavefront. + * + * Each structure instantiation must contain the following fields: + * - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the + * number of threads in a wavefront; + * - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier, + * it's 8 in case of DPP8; + * - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM + * operation; + * - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM + * operation; + * - m_per_lanegroup - size along M dimension that is processed by a single lanegroup; + * - n_per_lanegroup - size along N dimension that is processed by a single lanegroup; + * - m_per_thread - size along M dimension of the tile calculated by a single thread; + * - n_per_thread - size along N dimension of the tile calculated by a single thread; + * - k_per_dpp - size along K dimension that is reduced in a single DPP GEMM operation; + * - share_a - indicates whether we share matrix A or matrix B between lanes using DPP modifiers. + * + * Not all the combinarions are supported now, for current restrictions see the static asserts + * in the DppSelector's contructor. + */ +template +struct dpp_type; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 32; + static constexpr index_t n_per_wave = 8; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 8; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 16; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template +struct DppSelector +{ + template + static constexpr auto GetDpp(); + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_8x32x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_16x16x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_32x8x2; + } + + static constexpr auto selected_dpp = dpp_type()>{}; + + __host__ __device__ constexpr DppSelector() + { + static_assert(selected_dpp.m_per_wave % selected_dpp.m_per_lanegroup == 0); + static_assert(selected_dpp.n_per_wave % selected_dpp.n_per_lanegroup == 0); + + static_assert(selected_dpp.k_per_dpp % 2 == 0); + + static_assert(selected_dpp.wave_size % selected_dpp.lanegroup_size == 0); + constexpr index_t num_dpp_per_wave = selected_dpp.wave_size / selected_dpp.lanegroup_size; + constexpr index_t num_wave_c_elems = selected_dpp.m_per_wave * selected_dpp.n_per_wave; + constexpr index_t num_dpp_c_elems = + selected_dpp.m_per_lanegroup * selected_dpp.n_per_lanegroup; + static_assert(num_wave_c_elems % num_dpp_c_elems == 0); + static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems); + + if constexpr(selected_dpp.share_a) + { + static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); + static_assert(selected_dpp.n_per_lanegroup % selected_dpp.n_per_thread == 0); + static_assert(selected_dpp.n_per_lanegroup / selected_dpp.n_per_thread == + selected_dpp.lanegroup_size); + } + else + { + static_assert(selected_dpp.m_per_lanegroup % selected_dpp.n_per_thread == 0); + static_assert(selected_dpp.m_per_lanegroup / selected_dpp.n_per_thread == + selected_dpp.lanegroup_size); + static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread); + } + + // Below checks come from the restrictions of the current implementation, could be removed + // in the future when the implementation is more generalized. + static_assert(selected_dpp.share_a); + static_assert(selected_dpp.n_per_thread == 1); + static_assert(selected_dpp.m_per_thread == selected_dpp.lanegroup_size); + static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); + static_assert(selected_dpp.n_per_lanegroup == + selected_dpp.n_per_thread * selected_dpp.lanegroup_size); + } + + static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; } +}; + +template +struct DppGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __host__ __device__ constexpr DppGemm() + { + static_assert(MPerDpp == 8 || MPerDpp == 16 || MPerDpp == 32, + "MPerDpp must be either 8, 16 or 32."); + static_assert(NPerDpp == 8 || NPerDpp == 16 || NPerDpp == 32, + "NPerDpp must be either 8, 16 or 32."); + + static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp."); + } + + __device__ static constexpr index_t GetRegSizePerDpp() + { + return MPerDpp * NPerDpp / dpp_instr.wave_size; + } + + template + __device__ void + Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "base BaseType must be double, float, half, bfloat16, and int8_t!"); + + static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) { + dpp_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); + }); + } + + __device__ static auto GetLaneIdInWave() + { + return get_thread_local_1d_id() % dpp_instr.wave_size; + } + + __device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; } + + __device__ static auto GetLaneIdInLaneGroup() + { + return get_thread_local_1d_id() % dpp_instr.lanegroup_size; + } + + __device__ static auto GetLaneGroupIdInWave() + { + return GetLaneIdInWave() / dpp_instr.lanegroup_size; + } + + __device__ static auto GetDppOpIdx() + { + const auto lanegroupId = GetLaneGroupIdInWave(); + + constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup, + dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex( + make_multi_index(lanegroupId)); + + const auto m_dpp_idx = dpp_idx[I0]; + const auto n_dpp_idx = dpp_idx[I1]; + + return make_tuple(m_dpp_idx, n_dpp_idx); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M() + { + const auto laneId = get_thread_local_1d_id(); + const auto wave_row = laneId / dpp_instr.n_per_wave; + auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup(); + return make_tuple(0, m_idx % dpp_instr.m_per_wave); + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N() + { + const auto laneId = get_thread_local_1d_id(); + return make_tuple(0, laneId % dpp_instr.n_per_wave); + } + + __device__ static CIndex GetBeginOfThreadBlk() + { + const auto dpp_op_idx = GetDppOpIdx(); + + const auto m_dpp_op_idx = dpp_op_idx[I0]; + const auto n_dpp_op_idx = dpp_op_idx[I1]; + + index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup(); + index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup; + + return CIndex{m_offset, n_offset}; + } + + static constexpr auto dpp = DppSelector{}; + + static constexpr auto dpp_instr = dpp.selected_dpp; + + static constexpr auto K0PerDpp = 1; + static constexpr auto K1PerDpp = dpp.GetK1PerDpp(); + + __host__ __device__ static constexpr auto GetCMNThreadBlkLengths() + { + return make_tuple(Number{}, Number{}); + } +}; + +} // namespace ck diff --git a/include/ck/utility/amd_gemm_dpp.hpp b/include/ck/utility/amd_gemm_dpp.hpp index 8d6c7eede..a28292dad 100644 --- a/include/ck/utility/amd_gemm_dpp.hpp +++ b/include/ck/utility/amd_gemm_dpp.hpp @@ -5,17 +5,63 @@ #include "ck/utility/common_header.hpp" #include "ck/utility/math.hpp" -#include "ck/utility/amd_gemm_dpp.hpp" +#include "ck/utility/inner_product_dpp8.hpp" namespace ck { namespace dpp8 { -/// Number of lanes that can share data using DPP8 modifiers. -constexpr index_t lane_group_size = 8; +template +struct dpp_datatypes; -__device__ index_t get_lane_group_local_idx() { return threadIdx.x / lane_group_size; } -__device__ index_t get_thread_idx_in_lane_group() { return threadIdx.x % lane_group_size; } +template <> +struct dpp_datatypes +{ + // Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a + // single instruction. + using a_dtype = half_t; + using b_dtype = half_t; + using c_dtype = float; + static constexpr index_t k_per_instr = 2; +}; + +template +struct DppLanegroupGemm +{ + using datatypes_conf = dpp_datatypes; + using ADataType = typename datatypes_conf::a_dtype; + using BDataType = typename datatypes_conf::b_dtype; + using CDataType = typename datatypes_conf::c_dtype; + + __device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec) + { + constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread; + + const vector_type a_vector{a_vec}; + const vector_type b_vector{b_vec}; + + static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) { + float c = c_vec.template AsType()(c_idx); + // Next `c_idx` implies that we need to pull data from the next lane. + constexpr index_t source_lane = c_idx; + static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) { + const auto a_k_vec = a_vector.template AsType()[k_chunk]; + const auto b_k_vec = b_vector.template AsType()[k_chunk]; + ck::dpp8:: + inner_product_dpp( + a_k_vec, b_k_vec, c); + }); + c_vec.template AsType()(c_idx) = c; + }); + } +}; } // namespace dpp8 diff --git a/include/ck/utility/inner_product_dpp8.hpp b/include/ck/utility/inner_product_dpp8.hpp index ccd7a4e62..f079e2ca6 100644 --- a/include/ck/utility/inner_product_dpp8.hpp +++ b/include/ck/utility/inner_product_dpp8.hpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once + #include "amd_gemm_dpp.hpp" #include "data_type.hpp" #include "type_convert.hpp" @@ -10,6 +11,9 @@ namespace ck { namespace dpp8 { +/// Number of lanes that can share data using DPP8 modifiers. +constexpr index_t lane_group_size = 8; + template __device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c); diff --git a/include/ck/utility/loop_scheduler.hpp b/include/ck/utility/loop_scheduler.hpp new file mode 100644 index 000000000..b2eb0ddb9 --- /dev/null +++ b/include/ck/utility/loop_scheduler.hpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +enum struct LoopScheduler +{ + Default, + Interwave, +}; + +constexpr LoopScheduler make_default_loop_scheduler() +{ +#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING + return LoopScheduler::Interwave; +#else + return LoopScheduler::Default; +#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING +} + +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp index 9310fd433..b4fa74648 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -23,7 +23,7 @@ void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( DeviceGemm>>& instances); -void add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances( +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( std::vector>>& instances); @@ -38,7 +38,7 @@ void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( DeviceGemm>>& instances); -void add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances( +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( std::vector>>& instances); @@ -53,7 +53,7 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( DeviceGemm>>& instances); -void add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances( +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( std::vector>>& instances); @@ -68,7 +68,7 @@ void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( DeviceGemm>>& instances); -void add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances( +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( std::vector>>& instances); @@ -374,7 +374,7 @@ struct DeviceOperationInstanceFactory< #ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); - add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); } @@ -385,7 +385,7 @@ struct DeviceOperationInstanceFactory< #ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); - add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); @@ -397,7 +397,7 @@ struct DeviceOperationInstanceFactory< #ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); - add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); } @@ -408,7 +408,7 @@ struct DeviceOperationInstanceFactory< #ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); - add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); } diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 5162617b1..48dd292f0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -31,10 +31,10 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp) endif() list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp deleted file mode 100644 index 53fc74a39..000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instance.cpp +++ /dev/null @@ -1,61 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 8, 1, S<1, 8>, S<1, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<1, 1, 4, 2>, S<16, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<16, 1, 2, 2>, S<1, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<1, 1, 4, 2>, S<16, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 8, 2>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<8, 1>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 8, 1, S<2, 8>, S<4, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 8, 1, 1, S<1, 1>, S<8, 8>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 16, 2, 4, 8, 1, S<2, 8>, S<8, 1>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 1, 8, 8, S<4, 8>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<1, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 8, 8, S<2, 8>, S<16, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp deleted file mode 100644 index 35d46b32f..000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instance.cpp +++ /dev/null @@ -1,61 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 8, 1, S<1, 8>, S<1, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<1, 1, 4, 2>, S<16, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<16, 1, 2, 2>, S<1, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<1, 1, 4, 2>, S<16, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 8, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<8, 1>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 8, 1, S<2, 8>, S<4, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 8, 1, 1, S<1, 1>, S<8, 8>, S<1, 1, 2, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 16, 2, 4, 8, 1, S<2, 8>, S<8, 1>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 1, 8, 8, S<4, 8>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<1, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 8, 8, S<2, 8>, S<16, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<4, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp deleted file mode 100644 index 22e2e5f39..000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,60 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 8, 1, S<1, 8>, S<1, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0 ,3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<4, 1, 1, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<16, 1, 2, 2>, S<1, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 2, 1>, S<0, 3, 1, 2>, S<1, 1, 2, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<4, 1, 1, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 8, 2>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<8, 1>, S<2, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<2, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 8, 1, S<2, 8>, S<4, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 8, 1, 1, S<1, 1>, S<8, 8>, S<2, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<2, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 16, 2, 4, 8, 1, S<2, 8>, S<8, 1>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 1, 8, 8, S<4, 8>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 8, 8, S<2, 8>, S<16, 1>, S<4, 1, 2, 2>, S<4, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp deleted file mode 100644 index 49e9416c8..000000000 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instance.cpp +++ /dev/null @@ -1,61 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - -// Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances = std::tuple< - // clang-format off - // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | - // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 8, 8, 8, 4, 2, 1, 8, 1, S<1, 8>, S<1, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0 ,3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<4, 1, 1, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<16, 1, 2, 2>, S<1, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<4, 1>, S<4, 1, 1, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 8, 2>, S<4, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 1, 8, 1, S<1, 8>, S<8, 1>, S<2, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<2, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 8, 2, 4, 8, 1, S<2, 8>, S<4, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 8, 64, 16, 2, 8, 1, 1, S<1, 1>, S<8, 8>, S<2, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<2, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 16, 2, 4, 8, 1, S<2, 8>, S<8, 1>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 128, 64, 16, 2, 1, 8, 8, S<4, 8>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<4, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>, - DeviceGemmDlDpp8< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 16, 2, 4, 8, 8, S<2, 8>, S<16, 1>, S<4, 1, 2, 2>, S<4, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 2, 2>, S<4, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances) -{ - add_device_operation_instances(instances, device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 000000000..80da6d4c3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_km_kn_mn_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 4, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 4, 32, 8, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 4, 4, 32, 8, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 4, 4, 32, 8, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 4, 4, 32, 8, 2, 4, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 4, 4, 32, 8, 1, 4, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 4, 4, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dpp_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 000000000..97b85fd1a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_km_nk_mn_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 8, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 8, 32, 8, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 4, 8, 32, 8, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 4, 8, 32, 8, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 4, 8, 32, 8, 2, 4, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 4, 8, 32, 8, 1, 4, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 4, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dpp_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 000000000..370ebfccc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 4, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 4, 32, 8, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 8, 4, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 8, 4, 32, 8, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 8, 4, 32, 8, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 8, 4, 32, 8, 1, 4, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 8, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 000000000..605327549 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 8, 32, 8, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 8, 8, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 8, 8, 32, 8, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 8, 8, 32, 8, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 8, 8, 32, 8, 1, 4, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck -- GitLab From 562b4cec485ff735a46f46dd7da894b0c3ebdf59 Mon Sep 17 00:00:00 2001 From: Haocong WANG Date: Fri, 8 Sep 2023 10:59:26 +0800 Subject: [PATCH 007/112] [Navi3x] Add fp16/int8 wmma conv forward instances (#746) * fix wmma gemm int8; add grouped conv int8 example * Add int8 gemm-bilinear instances * compile sanity check unknown * Sanity pass + clang-format * add int8 conv profiler instances * solve merge conflict --------- Co-authored-by: zjing14 Co-authored-by: Chao Liu --- example/02_gemm_bilinear/CMakeLists.txt | 3 + .../gemm_bilinear_wmma_int8.cpp | 304 ++++++++++++++++++ .../gpu/block/blockwise_gemm_wmma.hpp | 135 +++++--- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 4 +- .../element/binary_element_wise_operation.hpp | 7 + .../device_operation_instance_factory.hpp | 1 + .../gpu/gemm_bilinear.hpp | 76 +++++ ...evice_grouped_conv2d_fwd_wmma_instance.hpp | 134 ++++++++ .../gpu/grouped_convolution_forward.hpp | 32 ++ .../gpu/gemm_bilinear/CMakeLists.txt | 4 + ...uffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp | 89 +++++ ...uffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp | 89 +++++ ...uffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp | 89 +++++ ...uffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp | 115 +++++++ .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 3 + ...wd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 66 ++++ ...fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp | 66 ++++ .../profile_grouped_conv_fwd_impl.hpp | 2 +- profiler/src/profile_gemm_bilinear.cpp | 19 ++ 19 files changed, 1194 insertions(+), 44 deletions(-) create mode 100644 example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index 9989e45e0..52e633052 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -5,6 +5,9 @@ set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list1 AND target EQUAL 0) add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) + add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp) +endif() +if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") set(target 1) endif() endforeach() diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp new file mode 100644 index 000000000..9f23ad265 --- /dev/null +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -0,0 +1,304 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +struct AlphaBetaAdd +{ + AlphaBetaAdd(int alpha, int beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + std::int8_t& e, const std::int32_t& c, const std::int8_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + int alpha_; + int beta_; +}; + +template +using S = ck::Sequence; + +using I8 = std::int8_t; +using I32 = std::int32_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using DDataType = I8; +using EDataType = I8; + +using ALayout = Row; +using BLayout = Row; +using DLayout = Row; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle, + ELayout, + ADataType, + BDataType, + ck::Tuple, + EDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 32, + 16, + 16, + 4, + 16, + 16, + 16, + 1, + 1, + S<2, 16, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + 1, + S<4, 1, 8>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 16, + 2, + 1, + 1, + 1, + S<1, 16, 1, 2>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + int alpha = 1; + int beta = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index 5ec964bd3..b3d45f3d0 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -221,49 +221,102 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, I0, I0, I0), - a_thread_buf); - - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0), - b_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; + // basic intrinsic to determine loopover direction + if constexpr(MRepeat < NRepeat) + { + static_for<0, KPerBlock / WmmaK, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0), + b_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + }); }); - - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); }); - }); - }); + } + else + { + static_for<0, KPerBlock / WmmaK, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0), + b_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0), + a_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } } protected: diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index 7c726cd85..1d73a7238 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // check if it's 1x1, stride=1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t ConvStride = arg.conv_filter_strides_[i]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; @@ -616,7 +616,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // check if it's 1x1 conv for(index_t i = 0; i < NDimSpatial; ++i) { - const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; const index_t LeftPad = arg.input_left_pads_[i]; const index_t RightPad = arg.input_right_pads_[i]; diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 9ade02391..9fe0931cb 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -186,6 +186,13 @@ struct Bilinear y = type_convert(alpha_ * x0 + beta_ * ck::type_convert(x1)); }; + template <> + __host__ __device__ constexpr void operator()( + std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const + { + y = type_convert(x0 + ck::type_convert(x1)); + }; + float alpha_; float beta_; }; diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 5096f5c30..84d31ce26 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -31,6 +31,7 @@ using F64_Tuple = ck::Tuple; using F32_Tuple = ck::Tuple; using I32_Tuple = ck::Tuple; using I32_F32_Tuple = ck::Tuple; +using I8_Tuple = ck::Tuple; using F32_F32_Tuple = ck::Tuple; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp index 9d8eaada2..387499e58 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp @@ -69,6 +69,58 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance PassThrough, Bilinear>>>& instances); +void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances( + std::vector>>& instances); + // GEMM + Bilinear template && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(op_ptrs); + } + } return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp new file mode 100644 index 000000000..c9cf0f8e1 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using I8 = int8_t; +using I32 = int32_t; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using NHWGC = ck::tensor_layout::convolution::NHWGC; +using GNHWC = ck::tensor_layout::convolution::GNHWC; + +using GKYXC = ck::tensor_layout::convolution::GKYXC; + +using NHWGK = ck::tensor_layout::convolution::NHWGK; +using GNHWK = ck::tensor_layout::convolution::GNHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using device_grouped_conv2d_fwd_wmma_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // blocksize=256 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 256, 4, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 4, 8, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // blocksize=128 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 4, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 4, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 32, 256, 4, 8, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 256, 32, 4, 8, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // blocksize=64 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 4, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 128, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + // blocksize=32 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 64, 4, 8, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 64, 16, 4, 8, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 32, 32, 4, 8, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 16, 4, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +template +using device_grouped_conv2d_fwd_wmma_i8_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // blocksize=256 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 256, 4, 16, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 4, 16, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 8, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // blocksize=128 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 8, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 4, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 8, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 4, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 8, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 32, 256, 4, 16, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 256, 32, 4, 16, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // blocksize=64 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 4, 16, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 4, 16, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 32, 8, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 128, 4, 16, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + // blocksize=32 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 64, 4, 16, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 64, 16, 4, 16, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 32, 32, 4, 16, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 2a0da10c6..576b9d898 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -145,6 +145,19 @@ void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( + std::vector>>& instances); void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances( std::vector && is_same_v && @@ -414,6 +441,11 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); + } } else if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt index b9f1210b2..aef8fe86d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt @@ -4,5 +4,9 @@ add_instance_library(device_gemm_bilinear_instance device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp + device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp + device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp + device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp + device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp ) endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp new file mode 100644 index 000000000..73ea9cac0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = std::int8_t; +using I32 = std::int32_t; +using I8_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e[m, n] = bilinear(a[m, k] * b[k, n], d[m, n]) +using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instances = std::tuple< + // clang-format off + //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + + // M/N/K padding + //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> + + // clang-format on + >; + +void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp new file mode 100644 index 000000000..1f36113e6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = std::int8_t; +using I32 = std::int32_t; +using I8_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e[m, n] = bilinear(a[m, k] * b[k, n], d[m, n]) +using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances = std::tuple< + // clang-format off + //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + + // M/N/K padding + //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> + + // clang-format on + >; + +void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp new file mode 100644 index 000000000..688c46336 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = std::int8_t; +using I32 = std::int32_t; +using I8_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e[m, n] = bilinear(a[m, k] * b[k, n], d[m, n]) +using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances = std::tuple< + // clang-format off + //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + + // M/N/K padding + //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> + + // clang-format on + >; + +void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp new file mode 100644 index 000000000..5319bd860 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = std::int8_t; +using I32 = std::int32_t; +using I8_Tuple = ck::Tuple; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// e[m, n] = bilinear(a[m, k] * b[n, k], d[m, n]) +using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = std::tuple< + // clang-format off + // no padding + // N % 16 == 0 && K % 16 == 0 + //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + // M/N/K padding + // N % 16 == 0 && K % 16 == 0 + //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + // M/N/K padding + // N % 8 == 0 && K % 8 == 0 + //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + + // M/N/K padding + // N % 8 == 0 && K % 8 == 0 + //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4>, + + // M/N/K padding + // N % 1 == 0 && K % 8 == 0 + //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 1> + + // clang-format on + >; + +void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 708b37c42..6b0ad99d6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -12,6 +12,9 @@ add_instance_library(device_grouped_conv2d_fwd_instance # GNHWC, GKYXC, GNHWK device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + # WMMA + device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp + device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp # NHWGC, GKYXC, NHWGK device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp new file mode 100644 index 000000000..3904592fc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi ,wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp new file mode 100644 index 000000000..0cded93d6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi ,wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_i8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_i8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_i8_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_wmma_i8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index 9fadfe969..8d37c2888 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -215,7 +215,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification, const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); - std::cout << "xdl found " << op_ptrs.size() << " instances" << std::endl; + std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl; for(auto& op_ptr : op_ptrs) { diff --git a/profiler/src/profile_gemm_bilinear.cpp b/profiler/src/profile_gemm_bilinear.cpp index a1a48616b..4527a2fa0 100644 --- a/profiler/src/profile_gemm_bilinear.cpp +++ b/profiler/src/profile_gemm_bilinear.cpp @@ -71,6 +71,9 @@ int profile_gemm_bilinear(int argc, char* argv[]) using F16 = ck::half_t; using F32 = float; + using I8 = std::int8_t; + using I32 = std::int32_t; + using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -141,6 +144,22 @@ int profile_gemm_bilinear(int argc, char* argv[]) { return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Col{}, Col{}, Row{}, Row{}); } + else if(data_type == MatrixDataType::INT8_INT8_INT8_INT8 && layout == MatrixLayout::MK_KN_MN_MN) + { + return profile(I8{}, I8{}, I32{}, I8{}, I8{}, Row{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::INT8_INT8_INT8_INT8 && layout == MatrixLayout::MK_NK_MN_MN) + { + return profile(I8{}, I8{}, I32{}, I8{}, I8{}, Row{}, Col{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::INT8_INT8_INT8_INT8 && layout == MatrixLayout::KM_KN_MN_MN) + { + return profile(I8{}, I8{}, I32{}, I8{}, I8{}, Col{}, Row{}, Row{}, Row{}); + } + else if(data_type == MatrixDataType::INT8_INT8_INT8_INT8 && layout == MatrixLayout::KM_NK_MN_MN) + { + return profile(I8{}, I8{}, I32{}, I8{}, I8{}, Col{}, Col{}, Row{}, Row{}); + } else { std::cout << "this data_type & layout is not implemented" << std::endl; -- GitLab From 8f84a01237e0ea12d63ff59e08ee90b8afd5eecc Mon Sep 17 00:00:00 2001 From: Bartlomiej Wroblewski Date: Fri, 8 Sep 2023 18:14:57 +0200 Subject: [PATCH 008/112] Enable DPP8 GEMM on Navi3 (#892) --- .../ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp | 3 ++- include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp index 162f0690b..24393511c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp @@ -168,7 +168,8 @@ struct DeviceGemmDpp : public DeviceGemm Date: Mon, 11 Sep 2023 11:01:36 -0600 Subject: [PATCH 009/112] Add codeowners for documentation (#902) Co-authored-by: samjwu --- .github/CODEOWNERS | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 000000000..30f0dedd8 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,6 @@ +# Documentation files +docs/* @saadrahim @LisaDelaney +*.md @saadrahim @LisaDelaney +*.rst @saadrahim @LisaDelaney +# Header directory +library/include/* @saadrahim @LisaDelaney -- GitLab From 547dbcfbc2514768e7c3b8b71bff053238cd878c Mon Sep 17 00:00:00 2001 From: Bartlomiej Wroblewski Date: Tue, 12 Sep 2023 17:05:23 +0200 Subject: [PATCH 010/112] Add new instances and support for small cases in DPP8 GEMM (#896) --- .../ck/tensor_operation/gpu/warp/dpp_gemm.hpp | 230 +++++++++++++++++- .../tensor_operation_instance/gpu/gemm.hpp | 32 ++- .../gpu/gemm/CMakeLists.txt | 4 + ...gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp | 16 +- ...16_f16_f16_km_kn_mn_irregular_instance.cpp | 64 +++++ ...gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp | 16 +- ...16_f16_f16_km_nk_mn_irregular_instance.cpp | 64 +++++ ...gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp | 16 +- ...16_f16_f16_mk_kn_mn_irregular_instance.cpp | 63 +++++ ...gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp | 16 +- ...16_f16_f16_mk_nk_mn_irregular_instance.cpp | 64 +++++ 11 files changed, 562 insertions(+), 23 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp diff --git a/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp index 03a4d17c9..a18443164 100644 --- a/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp @@ -11,8 +11,14 @@ namespace ck { enum struct DppInstr { - dpp8_f16_16x16x2 = 0, + dpp8_f16_1x32x2 = 0, + dpp8_f16_2x16x2, + dpp8_f16_2x32x2, + dpp8_f16_4x16x2, + dpp8_f16_4x32x2, + dpp8_f16_8x16x2, dpp8_f16_8x32x2, + dpp8_f16_16x16x2, dpp8_f16_32x8x2 }; @@ -101,6 +107,36 @@ struct dpp_type } }; +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 8; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 4; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 4; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + template <> struct dpp_type { @@ -131,6 +167,156 @@ struct dpp_type } }; +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 4; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 4; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 4; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 4; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 2; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 2; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 1; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 1; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 1; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 2; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 2; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 2; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 2; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 1; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 1; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + template struct DppSelector { @@ -143,6 +329,12 @@ struct DppSelector return DppInstr::dpp8_f16_8x32x2; } + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_8x16x2; + } + template <> static constexpr auto GetDpp() { @@ -155,6 +347,36 @@ struct DppSelector return DppInstr::dpp8_f16_32x8x2; } + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_1x32x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_2x32x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_2x16x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_4x16x2; + } + + template <> + static constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_4x32x2; + } + static constexpr auto selected_dpp = dpp_type()>{}; __host__ __device__ constexpr DppSelector() @@ -191,7 +413,6 @@ struct DppSelector // in the future when the implementation is more generalized. static_assert(selected_dpp.share_a); static_assert(selected_dpp.n_per_thread == 1); - static_assert(selected_dpp.m_per_thread == selected_dpp.lanegroup_size); static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread * selected_dpp.lanegroup_size); @@ -215,11 +436,6 @@ struct DppGemm __host__ __device__ constexpr DppGemm() { - static_assert(MPerDpp == 8 || MPerDpp == 16 || MPerDpp == 32, - "MPerDpp must be either 8, 16 or 32."); - static_assert(NPerDpp == 8 || NPerDpp == 16 || NPerDpp == 32, - "NPerDpp must be either 8, 16 or 32."); - static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp."); } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp index b4fa74648..3ceea58c1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -23,12 +23,17 @@ void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( DeviceGemm>>& instances); +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances( + std::vector>>& + instances); + void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( std::vector>>& instances); -void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances( +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances( std::vector>>& instances); @@ -38,12 +43,17 @@ void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( DeviceGemm>>& instances); +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances( + std::vector>>& + instances); + void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( std::vector>>& instances); -void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances( +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances( std::vector>>& instances); @@ -53,12 +63,17 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( DeviceGemm>>& instances); +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( std::vector>>& instances); -void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances( +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances( std::vector>>& instances); @@ -68,12 +83,17 @@ void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( DeviceGemm>>& instances); +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& + instances); + void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( std::vector>>& instances); -void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances( std::vector>>& instances); @@ -375,6 +395,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); } @@ -386,6 +407,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); @@ -398,6 +420,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); } @@ -409,6 +432,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); } diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 48dd292f0..6c76ee843 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -32,9 +32,13 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp) + list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp) endif() list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp) list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp index 80da6d4c3..19e4d875c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp @@ -36,11 +36,21 @@ using device_gemm_dpp_f16_f16_f16_km_kn_mn_instances = std::tuple< // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 4, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 4, 32, 8, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 64, 64, 4, 4, 16, 16, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 4, 4, 32, 8, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 64, 4, 4, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 4, 4, 32, 8, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, - DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 4, 4, 32, 8, 2, 4, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, - DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 4, 4, 32, 8, 1, 4, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, - DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 4, 4, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1> + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 1, 128, 64, 4, 4, 1, 32, 1, 1, S<4, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 32, 64, 4, 4, 8, 32, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 64, 4, 4, 8, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 4, 32, 64, 4, 4, 2, 32, 1, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 2, 64, 64, 4, 4, 2, 32, 1, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 2, 32, 64, 4, 4, 1, 32, 1, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 1, 64, 64, 4, 4, 1, 32, 1, 1, S<4, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 16, 64, 4, 4, 4, 16, 2, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 16, 32, 4, 4, 8, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 4, 32, 32, 4, 4, 4, 32, 1, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 2, 16, 32, 4, 4, 2, 16, 1, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1> >; // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp new file mode 100644 index 000000000..23e3152a1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 256, 128, 128, 64, 4, 4, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 256, 64, 64, 64, 4, 4, 16, 16, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 128, 64, 64, 64, 4, 4, 32, 8, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 128, 32, 32, 64, 4, 4, 32, 8, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 128, 4, 32, 64, 4, 4, 1, 32, 1, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 64, 64, 32, 4, 4, 32, 8, 2, 4, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 16, 32, 64, 4, 4, 8, 32, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 16, 16, 64, 4, 4, 8, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 2, 32, 64, 4, 4, 1, 32, 1, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 32, 32, 16, 64, 4, 4, 16, 16, 2, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 32, 8, 16, 64, 4, 4, 4, 16, 2, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 32, 1, 32, 64, 4, 4, 1, 32, 1, 1, S<4, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp index 97b85fd1a..b1533b01a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp @@ -36,11 +36,21 @@ using device_gemm_dpp_f16_f16_f16_km_nk_mn_instances = std::tuple< // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 8, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 4, 8, 32, 8, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 64, 64, 4, 8, 16, 16, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 4, 8, 32, 8, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 4, 8, 32, 8, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, - DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 4, 8, 32, 8, 2, 4, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, - DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 4, 8, 32, 8, 1, 4, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, - DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 4, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1> + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 1, 128, 64, 4, 8, 1, 32, 1, 1, S<4, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 32, 64, 4, 8, 8, 32, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 64, 4, 8, 8, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 4, 32, 64, 4, 8, 2, 32, 1, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 2, 64, 64, 4, 8, 2, 32, 1, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 2, 32, 64, 4, 8, 1, 32, 1, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 1, 64, 64, 4, 8, 1, 32, 1, 1, S<4, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 16, 64, 4, 8, 4, 16, 2, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 16, 32, 4, 8, 8, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 4, 32, 32, 4, 8, 4, 32, 1, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 2, 16, 32, 4, 8, 2, 16, 1, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1> >; // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp new file mode 100644 index 000000000..a2fc5926a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 256, 128, 128, 64, 4, 8, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 128, 64, 64, 64, 4, 8, 32, 8, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 128, 32, 64, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 128, 32, 32, 64, 4, 8, 32, 8, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 128, 4, 32, 64, 4, 8, 1, 32, 1, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 64, 64, 32, 4, 8, 32, 8, 2, 4, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 16, 32, 64, 4, 8, 8, 32, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 16, 16, 64, 4, 8, 8, 16, 1, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 2, 32, 64, 4, 8, 1, 32, 1, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 32, 32, 16, 64, 4, 8, 16, 16, 2, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 32, 8, 16, 64, 4, 8, 4, 16, 2, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 32, 1, 32, 64, 4, 8, 1, 32, 1, 1, S<4, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp index 370ebfccc..2afa28a45 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp @@ -35,11 +35,21 @@ using device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances = std::tuple< // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 4, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 4, 32, 8, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 64, 64, 8, 4, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 8, 4, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 64, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 8, 4, 32, 8, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, - DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 8, 4, 32, 8, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, - DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 8, 4, 32, 8, 1, 4, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, - DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 8, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1> + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 1, 128, 64, 8, 4, 1, 32, 1, 1, S<4, 1, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 32, 64, 8, 4, 8, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 64, 8, 4, 8, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 4, 32, 64, 8, 4, 2, 32, 1, 1, S<4, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 2, 64, 64, 8, 4, 2, 32, 1, 1, S<4, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 2, 32, 64, 8, 4, 1, 32, 1, 1, S<4, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 1, 64, 64, 8, 4, 1, 32, 1, 1, S<4, 1, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 16, 64, 8, 4, 4, 16, 2, 1, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 16, 32, 8, 4, 8, 16, 1, 1, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 4, 32, 32, 8, 4, 4, 32, 1, 1, S<4, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 2, 16, 32, 8, 4, 2, 16, 1, 1, S<4, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1> >; // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp new file mode 100644 index 000000000..508b2e8df --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 256, 128, 128, 64, 8, 4, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 256, 64, 64, 64, 8, 4, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 128, 64, 64, 64, 8, 4, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 128, 32, 32, 64, 8, 4, 32, 8, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 128, 4, 32, 64, 8, 4, 1, 32, 1, 1, S<4, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 64, 64, 32, 8, 4, 32, 8, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 16, 32, 64, 8, 4, 8, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 16, 16, 64, 8, 4, 8, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 2, 32, 64, 8, 4, 1, 32, 1, 1, S<4, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 32, 32, 16, 64, 8, 4, 16, 16, 2, 1, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 32, 8, 16, 64, 8, 4, 4, 16, 2, 1, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, MNPadding, 32, 1, 32, 64, 8, 4, 1, 32, 1, 1, S<4, 1, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp index 605327549..242ff07cb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp @@ -36,11 +36,21 @@ using device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances = std::tuple< // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 64, 8, 8, 32, 8, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 64, 64, 8, 8, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 32, 32, 8, 8, 32, 8, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, - DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 64, 8, 8, 32, 8, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, - DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 32, 32, 32, 8, 8, 32, 8, 1, 4, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, - DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 16, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1> + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 1, 128, 64, 8, 8, 1, 32, 1, 1, S<4, 1, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 32, 64, 8, 8, 8, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 64, 8, 8, 8, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 4, 32, 64, 8, 8, 2, 32, 1, 1, S<4, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 2, 64, 64, 8, 8, 2, 32, 1, 1, S<4, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 2, 32, 64, 8, 8, 1, 32, 1, 1, S<4, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 1, 64, 64, 8, 8, 1, 32, 1, 1, S<4, 1, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 16, 64, 8, 8, 4, 16, 2, 1, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 8, 16, 32, 8, 8, 8, 16, 1, 1, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 4, 32, 32, 8, 8, 4, 32, 1, 1, S<4, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 32, 2, 16, 32, 8, 8, 2, 16, 1, 1, S<4, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1> >; // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp new file mode 100644 index 000000000..241fd40b7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +// clang-format off +using device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances = std::tuple< + // ########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MDpp| NDpp| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + // ########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Dpp| Dpp| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + // ########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + // ########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 256, 64, 64, 64, 8, 8, 16, 16, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 128, 64, 64, 64, 8, 8, 32, 8, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 128, 32, 32, 64, 8, 8, 32, 8, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 128, 4, 32, 64, 8, 8, 1, 32, 1, 1, S<4, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 64, 64, 32, 8, 8, 32, 8, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 16, 32, 64, 8, 8, 8, 32, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 16, 16, 64, 8, 8, 8, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 64, 2, 32, 64, 8, 8, 1, 32, 1, 1, S<4, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 32, 32, 16, 64, 8, 8, 16, 16, 2, 1, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 32, 8, 16, 32, 8, 8, 8, 16, 1, 1, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1>, + DeviceGemmDpp< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, MNPadding, 32, 1, 32, 64, 8, 8, 1, 32, 1, 1, S<4, 1, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 5, 1> + >; +// clang-format on + +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, + device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck -- GitLab From 56c0279bbd822b9736b4e97d73331e1e5ccb27bb Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 12 Sep 2023 08:57:12 -0700 Subject: [PATCH 011/112] clean up the workspace after every stage (#909) --- Jenkinsfile | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 668d9a613..87d757e5c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -367,8 +367,6 @@ def runCKProfiler(Map conf=[:]){ withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 24, unit: 'HOURS') { - //cmake_build(conf) - //instead of building, just unstash the ckProfiler and install it sh """ rm -rf build mkdir build @@ -663,15 +661,12 @@ pipeline { } stages{ stage("Build Docker"){ - //when { - // beforeAgent true - // expression { params.BUILD_DOCKER.toBoolean() } - //} parallel{ stage('Docker /opt/rocm'){ agent{ label rocmnode("nogpu") } steps{ buildDocker('/opt/rocm') + cleanWs() } } } @@ -693,6 +688,7 @@ pipeline { } steps{ buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) + cleanWs() } } } @@ -715,6 +711,7 @@ pipeline { } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() } } stage("Build CK and run Tests on MI100/MI200") @@ -730,6 +727,7 @@ pipeline { } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() } } stage("Build CK and run Tests on Navi21") @@ -742,10 +740,10 @@ pipeline { environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ - } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() } } stage("Build CK and run Tests on Navi32") @@ -756,12 +754,12 @@ pipeline { } agent{ label rocmnode("navi32") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DDTYPES="fp16;fp32;bf16" -DGPU_TARGETS="gfx1101" """ - execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDTYPES="fp16;fp32;bf16" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ - + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1101" """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() } } } @@ -784,6 +782,7 @@ pipeline { } steps{ runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') + cleanWs() } } stage("Run ckProfiler: gfx90a") @@ -799,6 +798,7 @@ pipeline { } steps{ runPerfTest(setup_args:setup_args, config_targets: "ckProfiler", no_reboot:true, build_type: 'Release') + cleanWs() } } } @@ -811,6 +811,7 @@ pipeline { agent { label 'mici' } steps{ process_results() + cleanWs() } } } -- GitLab From 62d4af74491c153c196237575087843792553714 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Tue, 12 Sep 2023 17:04:27 -0500 Subject: [PATCH 012/112] Refactor f8_t, add bf8_t (#792) * Refactor f8_t to add bf8_t * Add check_err impl for f8_t * Update fp8 test * Format * Revert the fix * Update vector_type implementation * Add bf8 test * Add bf8, use BitInt types * Add bf8 conversion methods * Update type_convert for fp8/bf8 * Add check_err fp8/bf8 support * Add subnorm fp8 tests * Add subnorm bf8 tests * Fix conversion * Add bf8 cmake bindings * Add macros to enable build with disabled fp8/bf8 * Remove is_native method * Update flag combination for mixed precision instances * Add more flag checks * Add another flag to a client example * Add type traits, decouple f8/bf8 casting * Clean up * Decouple fp8 and bf8 flags * Remove more redundant flags * Remove leftover comments --- CMakeLists.txt | 13 +- client_example/20_splitk_gemm/CMakeLists.txt | 6 +- example/01_gemm/CMakeLists.txt | 6 +- include/ck/config.h.in | 7 + .../element/unary_element_wise_operation.hpp | 4 + .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 10 +- include/ck/utility/amd_buffer_addressing.hpp | 18 +- include/ck/utility/amd_xdlops.hpp | 2 + include/ck/utility/data_type.hpp | 131 ++++++++++- include/ck/utility/f8_utils.hpp | 217 +++++++++--------- include/ck/utility/type_convert.hpp | 104 ++++++++- .../device_operation_instance_factory.hpp | 7 +- .../gpu/gemm_multiply_add.hpp | 4 + .../gpu/gemm_splitk.hpp | 4 + .../include/ck/library/utility/check_err.hpp | 94 ++++++++ .../gpu/gemm_multiply_add/CMakeLists.txt | 18 +- .../gpu/gemm_splitk/CMakeLists.txt | 2 +- .../profiler/profile_gemm_splitk_impl.hpp | 4 + profiler/src/profile_gemm_multiply_add.cpp | 6 +- profiler/src/profile_gemm_splitk.cpp | 6 +- test/data_type/CMakeLists.txt | 11 +- test/data_type/bf8.cpp | 158 +++++++++++++ test/data_type/{fp8.cpp => f8.cpp} | 81 +++++-- 23 files changed, 740 insertions(+), 173 deletions(-) create mode 100644 test/data_type/bf8.cpp rename test/data_type/{fp8.cpp => f8.cpp} (58%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4ca54d847..a79af9cc3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,6 +15,12 @@ if (DTYPES) if (DTYPES MATCHES "fp8") add_definitions(-DCK_ENABLE_FP8) set(CK_ENABLE_FP8 "ON") + add_compile_options(-Wno-bit-int-extension) + endif() + if (DTYPES MATCHES "bf8") + add_definitions(-DCK_ENABLE_BF8) + set(CK_ENABLE_BF8 "ON") + add_compile_options(-Wno-bit-int-extension) endif() if (DTYPES MATCHES "fp16") add_definitions(-DCK_ENABLE_FP16) @@ -34,8 +40,9 @@ if (DTYPES) endif() message("DTYPES macro set to ${DTYPES}") else() - add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) + add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16) set(CK_ENABLE_ALL_DTYPES "ON") + add_compile_options(-Wno-bit-int-extension) # enable fp8 and bf8 endif() if(DL_KERNELS) @@ -365,6 +372,10 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu #message("fp8 instance found!") set(add_inst 1) endif() + if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf8\" " AND DTYPES MATCHES "bf8") + #message("bf8 instance found!") + set(add_inst 1) + endif() if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16") #message("fp16 instance found!") set(add_inst 1) diff --git a/client_example/20_splitk_gemm/CMakeLists.txt b/client_example/20_splitk_gemm/CMakeLists.txt index a60bada47..5571ed1d7 100644 --- a/client_example/20_splitk_gemm/CMakeLists.txt +++ b/client_example/20_splitk_gemm/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) -target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations) +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) + target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations) +endif() diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 3dc2a0966..5574d0900 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -69,5 +69,7 @@ if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) endif() endif() -add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) -add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8) +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) + add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8) +endif() diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 13dc5da5d..174834475 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -43,6 +43,9 @@ #ifndef CK_ENABLE_FP8 #define CK_ENABLE_FP8 "ON" #endif +#ifndef CK_ENABLE_BF8 +#define CK_ENABLE_BF8 "ON" +#endif #ifndef CK_ENABLE_FP16 #define CK_ENABLE_FP16 "ON" #endif @@ -66,6 +69,10 @@ #cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@ #endif +#ifndef CK_ENABLE_BF8 +#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@ +#endif + #ifndef CK_ENABLE_FP16 #cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@ #endif diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 905908a1c..34ac08b66 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -89,6 +89,7 @@ struct PassThrough } #endif +#if defined CK_ENABLE_FP8 template <> __host__ __device__ void operator()(f8_t& y, const f8_t& x) const { @@ -118,6 +119,7 @@ struct PassThrough { y = type_convert(x); } +#endif }; struct UnaryConvert @@ -146,6 +148,7 @@ struct ConvertBF16RTN } }; +#if defined CK_ENABLE_FP8 struct ConvertF8SR { // convert to fp8 using stochastic rounding (SR) @@ -162,6 +165,7 @@ struct ConvertF8SR y = f8_convert_sr(x); } }; +#endif struct Scale { diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 814969ef4..9ee07b84a 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -456,6 +456,7 @@ struct mfma_type } }; +#if defined CK_ENABLE_FP8 template <> struct mfma_type { @@ -499,6 +500,7 @@ struct mfma_type intrin_mfma_f32_16x16x32f8f8::Run(a, b, reg_c); } }; +#endif template struct MfmaSelector @@ -640,6 +642,7 @@ struct MfmaSelector } #endif +#if defined CK_ENABLE_FP8 template <> static constexpr auto GetMfma() { @@ -651,6 +654,7 @@ struct MfmaSelector { return MfmaInstr::mfma_f32_16x16x32f8f8; } +#endif static constexpr auto selected_mfma = mfma_type()>{}; @@ -852,7 +856,11 @@ struct XdlopsGemm { static_assert(is_same::value || is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value, + is_same::value +#if defined CK_ENABLE_FP8 + || is_same::value +#endif + , "base base_type must be double, float, half, bfloat16, and int8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 897cb4f24..694027100 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; - +#if defined CK_ENABLE_FP8 if constexpr(is_same::value) { auto tmp = amd_buffer_load_impl( @@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, } else { +#endif return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); +#if defined CK_ENABLE_FP8 } +#endif #else +#if defined CK_ENABLE_FP8 if constexpr(is_same::value) { auto tmp = amd_buffer_load_impl( @@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, } else { +#endif vector_t tmp = amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0); return src_thread_element_valid ? tmp : vector_t(0); +#if defined CK_ENABLE_FP8 } #endif +#endif } // buffer_load requires: @@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; - +#if defined CK_ENABLE_FP8 if constexpr(is_same::value) { auto tmp = @@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t } else { +#endif amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#if defined CK_ENABLE_FP8 } +#endif #else if(dst_thread_element_valid) { +#if defined CK_ENABLE_FP8 if constexpr(is_same::value) { auto tmp = bit_cast::type::type>( @@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t } else { +#endif amd_buffer_store_impl( src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); +#if defined CK_ENABLE_FP8 } +#endif } #endif } diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index ea7755036..a80540515 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> } }; +#if defined CK_ENABLE_FP8 template struct intrin_mfma_f32_32x32x16f8f8; @@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> #endif } }; +#endif } // namespace ck #endif diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index c240afa2b..89100577a 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -12,7 +12,12 @@ using half_t = _Float16; #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 using int4_t = _BitInt(4); #endif -using f8_t = uint8_t; +#if defined CK_ENABLE_FP8 +using f8_t = _BitInt(8); +#endif +#if defined CK_ENABLE_BF8 +using bf8_t = unsigned _BitInt(8); +#endif // vector_type template @@ -143,14 +148,24 @@ struct scalar_type }; #endif +#if defined CK_ENABLE_FP8 template <> struct scalar_type { using type = f8_t; static constexpr index_t vector_size = 1; }; +#endif + +#if defined CK_ENABLE_BF8 +template <> +struct scalar_type +{ + using type = bf8_t; + static constexpr index_t vector_size = 1; +}; +#endif -// template struct vector_type { @@ -953,12 +968,24 @@ using int8x32_t = typename vector_type::type; using int8x64_t = typename vector_type::type; // f8 +#if defined CK_ENABLE_FP8 using f8x2_t = typename vector_type::type; using f8x4_t = typename vector_type::type; using f8x8_t = typename vector_type::type; using f8x16_t = typename vector_type::type; using f8x32_t = typename vector_type::type; using f8x64_t = typename vector_type::type; +#endif + +// bf8 +#if defined CK_ENABLE_BF8 +using bf8x2_t = typename vector_type::type; +using bf8x4_t = typename vector_type::type; +using bf8x8_t = typename vector_type::type; +using bf8x16_t = typename vector_type::type; +using bf8x32_t = typename vector_type::type; +using bf8x64_t = typename vector_type::type; +#endif template struct NumericLimits @@ -1006,21 +1033,109 @@ struct NumericLimits }; #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +#if defined CK_ENABLE_FP8 template <> struct NumericLimits { + // negative zero nan mode with exp bias = 8 static constexpr uint8_t binary_min = 0x08; // 0b00001000 - static constexpr uint8_t binary_max = 0x77; // 0b01110111 - static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 7 + // static constexpr uint8_t binary_min = 0x08; // 0b00001000 + // static constexpr uint8_t binary_max = 0x77; // 0b01110111 + // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 + + __host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); } + + __host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); } + + __host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); } + + __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); } +}; +#endif + +#if defined CK_ENABLE_BF8 +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 16 + static constexpr uint8_t binary_min = 0x04; // 0b00000100 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 15 + // static constexpr uint8_t binary_min = 0x04; // 0b00000100 + // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 + // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= - __host__ __device__ static constexpr f8_t Min() { return bit_cast(binary_min); } + __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); } - __host__ __device__ static constexpr f8_t Max() { return bit_cast(binary_max); } + __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); } - __host__ __device__ static constexpr f8_t Lowest() { return bit_cast(binary_lowest); } + __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); } - __host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast(binary_qnan); } + __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); } }; +#endif + +template +struct NumericUtils +{ +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 23; + static constexpr uint32_t nan_mask = 0x7F800000; + static constexpr uint32_t head_mask = 0xFF800000; + static constexpr uint32_t mant_mask = 0x7FFFFF; + static constexpr uint32_t exp_mask = 0xFF; + static constexpr uint32_t Inf = 0x7F800000; + static constexpr uint32_t NegInf = 0xFF800000; + static constexpr uint32_t NaN = 0x7F800001; + static constexpr uint32_t Neg0 = 0x80000000; + using bitwise_type = uint32_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 10; + static constexpr uint16_t nan_mask = 0x7C00; + static constexpr uint16_t head_mask = 0xFC00; + static constexpr uint16_t mant_mask = 0x3FF; + static constexpr uint16_t exp_mask = 0x1F; + static constexpr uint32_t Inf = 0x7C00; + static constexpr uint32_t NegInf = 0xFC00; + static constexpr uint32_t NaN = 0x7C01; + static constexpr uint32_t Neg0 = 0x8000; + using bitwise_type = uint16_t; +}; + +#if defined CK_ENABLE_FP8 +template <> +struct NumericUtils +{ + static constexpr int exp = 4; + static constexpr int mant = 3; +}; +#endif + +#if defined CK_ENABLE_BF8 +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 2; +}; +#endif } // namespace ck diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index bb13f9815..5fbebb708 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -5,6 +5,7 @@ #include "ck/utility/data_type.hpp" +#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 namespace ck { // fp8 rounding modes @@ -22,53 +23,38 @@ namespace ck::utils { namespace { -template -__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) +template +__host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) { - // check data type - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; + // fp8/bf8 exponent/mantissa layout + constexpr int out_exp = NumericUtils::exp; + constexpr int out_mant = NumericUtils::mant; - // fp8 exponent/mantissa layout - constexpr int f8_exp = 4; - constexpr int f8_mant = 3; - - // resulting type exponent/mantissa layout - constexpr int type_exp = is_half ? 5 : 8; - constexpr int type_mant = is_half ? 10 : 23; + // original type exponent/mantissa layout + constexpr int in_exp = NumericUtils::exp; + constexpr int in_mant = NumericUtils::mant; int exponent; uint32_t head, mantissa, sign; // nan code is same for float and half - constexpr uint8_t nan_code = 0x80; - constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000; + constexpr Y nan_code = 0x80; + constexpr uint32_t nan_mask = NumericUtils::nan_mask; // convert to bitwise - typedef typename std::conditional::value, uint16_t, uint32_t>::type - T_bitwise; + using T_bitwise = typename NumericUtils::bitwise_type; T_bitwise x_bitwise = *(reinterpret_cast(&x)); // unpack the input, depends on datatype - if constexpr(is_float) - { - head = x_bitwise & 0xFF800000; - mantissa = x_bitwise & 0x7FFFFF; - exponent = (head >> type_mant) & 0xFF; - sign = head >> (type_exp + type_mant); - } - else if constexpr(is_half) - { - head = x_bitwise & 0xFC00; - mantissa = x_bitwise & 0x3FF; - exponent = (head >> type_mant) & 0x1F; - sign = head >> (type_exp + type_mant); - } - - uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant); - uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1; - constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2); + head = x_bitwise & NumericUtils::head_mask; + mantissa = x_bitwise & NumericUtils::mant_mask; + exponent = (head >> in_mant) & NumericUtils::exp_mask; + sign = head >> (in_exp + in_mant); + + uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant); + uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1; + constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2); constexpr int exp_low_cutoff = - (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + (1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); if constexpr(negative_zero_nan) { @@ -81,22 +67,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) return signed_inf + (mantissa != 0 ? 1 : 0); } + // if input is half and output is bf8 + if((NumericUtils::mant == 10) && (NumericUtils::mant == 2) && negative_zero_nan && + exponent == 0) + { + exponent += 1; + while(mantissa < (1 << in_mant)) + { + mantissa <<= 1; + exponent -= 1; + } + mantissa &= ~(1 << in_mant); + } + // check if x is 0.0 if(x_bitwise == 0) return 0; exponent -= exp_low_cutoff - 1; if(exponent <= 0) - drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1; - mantissa += 1 << type_mant; + drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1; + mantissa += 1 << in_mant; // apply random number if needed mantissa += (stoch ? rng : mantissa) & drop_mask; - if(mantissa >= (2 << type_mant)) + if(mantissa >= (2 << in_mant)) { mantissa >>= 1; exponent++; } - mantissa >>= (type_mant - f8_mant); + mantissa >>= (in_mant - out_mant); // check negative exponent if(exponent <= 0) @@ -116,7 +115,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) { if(clip) { - mantissa = (1 << f8_mant) - 1; + mantissa = (1 << out_mant) - 1; exponent = max_exp; } else @@ -127,124 +126,120 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng) // check if x is 0.0 or -0.0 if(exponent == 0 && mantissa == 0) - return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant)); - mantissa &= (1 << f8_mant) - 1; - return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa; + return negative_zero_nan ? 0 : (sign << (out_exp + out_mant)); + mantissa &= (1 << out_mant) - 1; + return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; } -template -__host__ __device__ T run_cast_from_f8(f8_t x) +template +__host__ __device__ Y run_cast_from_f8(X x) { - // check data type - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; - - // fp8 exponent/mantissa layout - constexpr int f8_exp = 4; - constexpr int f8_mant = 3; + // fp8/bf8 exponent/mantissa layout + constexpr int in_exp = NumericUtils::exp; + constexpr int in_mant = NumericUtils::mant; // resulting type exponent/mantissa layout - constexpr int type_exp = is_half ? 5 : 8; - constexpr int type_mant = is_half ? 10 : 23; + constexpr int out_exp = NumericUtils::exp; + constexpr int out_mant = NumericUtils::mant; // prepare the codes - constexpr uint8_t nan_code = 0x80; - T fInf, fNegInf, fNaN, fNeg0; - if constexpr(is_half) - { - constexpr uint16_t ihInf = 0x7C00; - constexpr uint16_t ihNegInf = 0xFC00; - constexpr uint16_t ihNaN = 0x7C01; - constexpr uint16_t ihNeg0 = 0x8000; - fInf = *(reinterpret_cast(&ihInf)); - fNegInf = *(reinterpret_cast(&ihNegInf)); - fNaN = *(reinterpret_cast(&ihNaN)); - fNeg0 = *(reinterpret_cast(&ihNeg0)); - } - else if constexpr(is_float) - { - constexpr uint32_t ifInf = 0x7F800000; - constexpr uint32_t ifNegInf = 0xFF800000; - constexpr uint32_t ifNaN = 0x7F800001; - constexpr uint32_t ifNeg0 = 0x80000000; - fInf = *(reinterpret_cast(&ifInf)); - fNegInf = *(reinterpret_cast(&ifNegInf)); - fNaN = *(reinterpret_cast(&ifNaN)); - fNeg0 = *(reinterpret_cast(&ifNeg0)); - } + constexpr X nan_code = 0x80; + Y Inf, NegInf, NaN, Neg0; + using T_bitwise = typename NumericUtils::bitwise_type; + + constexpr T_bitwise Inf_bitwise = NumericUtils::Inf; + constexpr T_bitwise NegInf_bitwise = NumericUtils::NegInf; + constexpr T_bitwise NaN_bitwise = NumericUtils::NaN; + constexpr T_bitwise Neg0_bitwise = NumericUtils::Neg0; + + Inf = *(reinterpret_cast(&Inf_bitwise)); + NegInf = *(reinterpret_cast(&NegInf_bitwise)); + NaN = *(reinterpret_cast(&NaN_bitwise)); + Neg0 = *(reinterpret_cast(&Neg0_bitwise)); + + // check if x is 0.0 + if(x == 0) + return static_cast(0); // unpack the input - uint32_t sign = x >> (f8_exp + f8_mant); - uint32_t mantissa = x & ((1 << f8_mant) - 1); - int exponent = (x & 0x7F) >> f8_mant; + uint32_t sign = x >> (in_exp + in_mant); + uint32_t mantissa = x & ((1 << in_mant) - 1); + int exponent = (x & 0x7F) >> in_mant; constexpr int exp_low_cutoff = - (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); - typename std::conditional::value, uint16_t, uint32_t>::type retval; + (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + T_bitwise retval; if constexpr(negative_zero_nan) { if(x == nan_code) - return fNaN; + return NaN; } else { if(x == nan_code) - return fNeg0; - if(exponent == ((1 << f8_exp) - 1)) - return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + return Neg0; + if(exponent == ((1 << in_exp) - 1)) + return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN; + } + + if((NumericUtils::mant == 10) && (NumericUtils::mant == 2) && !negative_zero_nan) + { + retval = x; + retval <<= 8; + return *(reinterpret_cast(&retval)); } // subnormal input if(exponent == 0) { // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above - int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant); - mantissa <<= sh; - mantissa &= ((1 << f8_mant) - 1); - exponent += 1 - sh; + exponent++; + while(mantissa < (1 << in_mant)) + { + mantissa <<= 1; + exponent--; + } + mantissa &= ((1 << in_mant) - 1); } exponent += exp_low_cutoff - 1; - mantissa <<= type_mant - f8_mant; + mantissa <<= out_mant - in_mant; // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) if(exponent <= 0) { - mantissa |= 1 << type_mant; + mantissa |= 1 << out_mant; mantissa >>= 1 - exponent; exponent = 0; } - retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa; - return *(reinterpret_cast(&retval)); + retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; + return *(reinterpret_cast(&retval)); } } // namespace -template -__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng) +template +__host__ __device__ Y cast_to_f8(X x, uint32_t rng) { - // check datatype - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; - static_assert(is_half || is_float, "Only half and float can be casted to f8."); + // check datatypes + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "Only half and float can be casted."); - return run_cast_to_f8(x, rng); + return run_cast_to_f8(x, rng); } -template -__host__ __device__ T cast_from_f8(f8_t x) +template +__host__ __device__ Y cast_from_f8(X x) { // check datatype - constexpr bool is_half = std::is_same::value; - constexpr bool is_float = std::is_same::value; + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; static_assert(is_half || is_float, "only half and float are supported."); - // check if x is 0.0 - if(x == 0) - return static_cast(0); - - return run_cast_from_f8(x); + return run_cast_from_f8(x); } } // namespace ck::utils +#endif diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 65d894037..5c5447f94 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -80,6 +80,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert(int8_ return type_convert(x_fp32); } +#if defined CK_ENABLE_FP8 // convert fp32 to fp8 template <> inline __host__ __device__ f8_t type_convert(float x) @@ -88,8 +89,9 @@ inline __host__ __device__ f8_t type_convert(float x) constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr uint32_t rng = 0; - return utils::cast_to_f8( - x, rng); + return utils:: + cast_to_f8(x, + rng); } // convert fp8 to fp32 @@ -97,7 +99,7 @@ template <> inline __host__ __device__ float type_convert(f8_t x) { constexpr bool negative_zero_nan = true; - return utils::cast_from_f8(x); + return utils::cast_from_f8(x); } // convert fp16 to fp8 @@ -108,8 +110,9 @@ inline __host__ __device__ f8_t type_convert(half_t x) constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; constexpr uint32_t rng = 0; - return utils::cast_to_f8( - x, rng); + return utils:: + cast_to_f8( + x, rng); } // convert fp8 to fp16 @@ -117,8 +120,53 @@ template <> inline __host__ __device__ half_t type_convert(f8_t x) { constexpr bool negative_zero_nan = true; - return utils::cast_from_f8(x); + return utils::cast_from_f8(x); } +#endif + +#if defined CK_ENABLE_BF8 +// convert fp32 to bf8 +template <> +inline __host__ __device__ bf8_t type_convert(float x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::standard; + constexpr uint32_t rng = 0; + return utils:: + cast_to_f8( + x, rng); +} + +// convert bf8 to fp32 +template <> +inline __host__ __device__ float type_convert(bf8_t x) +{ + constexpr bool negative_zero_nan = true; + return utils::cast_from_f8(x); +} + +// convert fp16 to bf8 +template <> +inline __host__ __device__ bf8_t type_convert(half_t x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::standard; + constexpr uint32_t rng = 0; + return utils:: + cast_to_f8( + x, rng); +} + +// convert bf8 to fp16 +template <> +inline __host__ __device__ half_t type_convert(bf8_t x) +{ + constexpr bool negative_zero_nan = true; + return utils::cast_from_f8(x); +} +#endif // Declare a template function for bf16 conversion using RTN template @@ -181,6 +229,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(h template __host__ __device__ constexpr Y f8_convert_sr(X x); +#if defined CK_ENABLE_FP8 // convert fp32 to fp8 with stochastic rounding template <> inline __host__ __device__ f8_t f8_convert_sr(float x) @@ -191,8 +240,9 @@ inline __host__ __device__ f8_t f8_convert_sr(float x) constexpr int seed = 42; // as thread id is not available on host, use 0 for prn generation uint32_t rng = prand_generator(reinterpret_cast(&x), x); - return utils::cast_to_f8( - x, rng); + return utils:: + cast_to_f8(x, + rng); } // convert fp16 to fp8 with stochastic rounding @@ -205,8 +255,42 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) constexpr int seed = 42; // as thread id is not available on host, use 0 for prn generation uint32_t rng = prand_generator(reinterpret_cast(&x), x); - return utils::cast_to_f8( - x, rng); + return utils:: + cast_to_f8( + x, rng); +} +#endif + +#if defined CK_ENABLE_BF8 +// convert fp32 to bf8 with stochastic rounding +template <> +inline __host__ __device__ bf8_t f8_convert_sr(float x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; + constexpr int seed = 42; + // as thread id is not available on host, use 0 for prn generation + uint32_t rng = prand_generator(reinterpret_cast(&x), x); + return utils:: + cast_to_f8( + x, rng); +} + +// convert fp16 to bf8 with stochastic rounding +template <> +inline __host__ __device__ bf8_t f8_convert_sr(half_t x) +{ + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; + constexpr int seed = 42; + // as thread id is not available on host, use 0 for prn generation + uint32_t rng = prand_generator(reinterpret_cast(&x), x); + return utils:: + cast_to_f8( + x, rng); } +#endif } // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 84d31ce26..ea11fd2e1 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -17,10 +17,15 @@ namespace instance { using F64 = double; using F32 = float; using F16 = ck::half_t; -using F8 = ck::f8_t; using BF16 = ck::bhalf_t; using I8 = int8_t; using I32 = int32_t; +#if defined CK_ENABLE_FP8 +using F8 = ck::f8_t; +#endif +#if defined CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif using Empty_Tuple = ck::Tuple<>; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp index badc06dd6..64c74d479 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_add.hpp @@ -45,6 +45,7 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_ PassThrough, MultiplyAdd>>>&); +#if defined CK_ENABLE_FP8 void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances( std::vector>>&); +#endif // GEMM + Multiply + Add template && is_same_v && is_same_v && is_same_v && is_same_v) @@ -150,6 +153,7 @@ struct DeviceOperationInstanceFactory>>& instances); +#if defined CK_ENABLE_FP8 void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances( std::vector>>& @@ -96,6 +97,7 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances( std::vector>>& instances); +#endif template && is_same_v && is_same_v) { @@ -224,6 +227,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances(op_ptrs); } } +#endif return op_ptrs; } diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 8a7263137..c0f9ba2ed 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -230,5 +230,99 @@ check_err(const Range& out, return res; } +#if defined CK_ENABLE_FP8 +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, f8_t>), + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} +#endif + +#if defined CK_ENABLE_BF8 +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, bf8_t>), + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} +#endif + } // namespace utils } // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt index a4f744345..36bd6a4aa 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt @@ -1,7 +1,13 @@ -add_instance_library(device_gemm_multiply_add_instance - device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp - device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp +set(GEMM_MULTIPLY_ADD_INSTANCES) - device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp - device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp -) +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp) + list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp) +endif() + +if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) + list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp) + list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp) +endif() + +add_instance_library(device_gemm_multiply_add_instance ${GEMM_MULTIPLY_ADD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index 89dfa8f2e..043b28a1b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -14,7 +14,7 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp) endif() -if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) +if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp) list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp) list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp) diff --git a/profiler/include/profiler/profile_gemm_splitk_impl.hpp b/profiler/include/profiler/profile_gemm_splitk_impl.hpp index fb68bb881..495513f66 100644 --- a/profiler/include/profiler/profile_gemm_splitk_impl.hpp +++ b/profiler/include/profiler/profile_gemm_splitk_impl.hpp @@ -214,6 +214,7 @@ bool profile_gemm_splitk_impl(int do_verification, << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " << kbatch_curr << std::endl; +#if defined CK_ENABLE_FP8 // set softer tolerances for fp8 if constexpr(is_same_v || is_same_v || is_same_v) @@ -226,8 +227,11 @@ bool profile_gemm_splitk_impl(int do_verification, } else { +#endif pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#if defined CK_ENABLE_FP8 } +#endif if(tflops > best_tflops) { diff --git a/profiler/src/profile_gemm_multiply_add.cpp b/profiler/src/profile_gemm_multiply_add.cpp index fd1f5c65c..98973b2f0 100644 --- a/profiler/src/profile_gemm_multiply_add.cpp +++ b/profiler/src/profile_gemm_multiply_add.cpp @@ -59,9 +59,11 @@ int profile_gemm_multiply_add(int argc, char* argv[]) const int StrideD1 = std::stoi(argv[14]); const int StrideE = std::stoi(argv[15]); - using F8 = ck::f8_t; using F16 = ck::half_t; using F32 = float; +#if defined CK_ENABLE_FP8 + using F8 = ck::f8_t; +#endif using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -132,6 +134,7 @@ int profile_gemm_multiply_add(int argc, char* argv[]) { return profile(F16{}, F16{}, F32{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{}); } +#if defined CK_ENABLE_FP8 else if(data_type == MatrixDataType::F16_F8_F32_F32_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN) { @@ -142,6 +145,7 @@ int profile_gemm_multiply_add(int argc, char* argv[]) { return profile(F16{}, F8{}, F32{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{}); } +#endif else { std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/profiler/src/profile_gemm_splitk.cpp b/profiler/src/profile_gemm_splitk.cpp index 617e0b9cd..9c805fc1d 100644 --- a/profiler/src/profile_gemm_splitk.cpp +++ b/profiler/src/profile_gemm_splitk.cpp @@ -67,7 +67,9 @@ int profile_gemm_splitk(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; - using F8 = ck::f8_t; +#if defined CK_ENABLE_FP8 + using F8 = ck::f8_t; +#endif using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -146,6 +148,7 @@ int profile_gemm_splitk(int argc, char* argv[]) { return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}); } +#if defined CK_ENABLE_FP8 else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) { return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); @@ -178,6 +181,7 @@ int profile_gemm_splitk(int argc, char* argv[]) { return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{}); } +#endif else { std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index 2b63727f1..baf017455 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -3,5 +3,12 @@ if (USE_BITINT_EXTENSION_INT4) target_link_libraries(test_int4 PRIVATE utility) endif() -add_gtest_executable(test_fp8 fp8.cpp) -target_link_libraries(test_fp8 PRIVATE utility) +if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) + add_gtest_executable(test_f8 f8.cpp) + target_link_libraries(test_f8 PRIVATE utility) +endif() + +if(DTYPES MATCHES "bf8" OR NOT DEFINED DTYPES) + add_gtest_executable(test_bf8 bf8.cpp) + target_link_libraries(test_bf8 PRIVATE utility) +endif() diff --git a/test/data_type/bf8.cpp b/test/data_type/bf8.cpp new file mode 100644 index 000000000..6a5fa281e --- /dev/null +++ b/test/data_type/bf8.cpp @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type_convert.hpp" + +using ck::bf8_t; +using ck::f8_convert_sr; +using ck::half_t; +using ck::type_convert; + +TEST(BF8, NumericLimits) +{ + // constants given for negative zero nan mode + EXPECT_EQ(ck::NumericLimits::Min(), type_convert(0x04)); + EXPECT_EQ(ck::NumericLimits::Max(), type_convert(0x7F)); + EXPECT_EQ(ck::NumericLimits::Lowest(), type_convert(0xFF)); + EXPECT_EQ(ck::NumericLimits::QuietNaN(), type_convert(0x80)); +} + +TEST(BF8, ConvertFP32Nearest) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // convert 0 float to bf8 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(type_convert(0.0f)), abs_tol); + // convert minimal float to bf8 and back, check if holds + ASSERT_NEAR(std::numeric_limits::min(), + type_convert(type_convert(std::numeric_limits::min())), + abs_tol); + // convert maximal bf8_t to float and check if equal to 57344.0 + ASSERT_NEAR(57344.0f, type_convert(type_convert(57344.0f)), abs_tol); + // convert maximal float to bf8 and back, check if clipped to 57344.0 + ASSERT_NEAR(57344.0f, + type_convert(type_convert(std::numeric_limits::max())), + abs_tol); + // convert inf float to bf8_t and check if it is qNan + ASSERT_NEAR(type_convert(0x80), + type_convert(std::numeric_limits::infinity()), + abs_tol); + // positive norm float value to bf8 and back, check if holds + float pos_float = 0.0000762939f; + ASSERT_NEAR(pos_float, type_convert(type_convert(pos_float)), abs_tol); + // negative norm float value to bf8 and back, check if holds + float neg_float = -0.0000610351f; + ASSERT_NEAR(neg_float, type_convert(type_convert(neg_float)), abs_tol); + // positive subnorm float value to bf8 and back, check if holds + pos_float = 0.0000305175f; + ASSERT_NEAR(pos_float, type_convert(type_convert(pos_float)), abs_tol); + // negative subnorm float value to bf8 and back, check if holds + neg_float = -0.0000152587f; + ASSERT_NEAR(neg_float, type_convert(type_convert(neg_float)), abs_tol); +} + +TEST(BF8, ConvertFP32Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-6; + // convert 0 float to bf8 and back, check if holds + ASSERT_NEAR(0.0f, type_convert(f8_convert_sr(0.0f)), abs_tol); + // convert minimal float to bf8 and back, check if holds + ASSERT_NEAR(std::numeric_limits::min(), + type_convert(f8_convert_sr(std::numeric_limits::min())), + abs_tol); + // convert maximal bf8_t to float and check if equal to 57344.0 + ASSERT_NEAR(57344.0f, type_convert(f8_convert_sr(57344.0f)), abs_tol); + // convert maximal float to bf8 and back, check if clipped to 57344.0 + ASSERT_NEAR(57344.0f, + type_convert(f8_convert_sr(std::numeric_limits::max())), + abs_tol); + // convert inf float to bf8_t and check if it is qNan + ASSERT_NEAR(type_convert(0x80), + f8_convert_sr(std::numeric_limits::infinity()), + abs_tol); + // positive norm float value to bf8 and back, check if holds + float pos_float = 0.0000762939f; + ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); + // negative norm float value to bf8 and back, check if holds + float neg_float = -0.0000610351f; + ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); + // positive subnorm float value to bf8 and back, check if holds + pos_float = 0.0000305175f; + ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); + // negative subnorm float value to bf8 and back, check if holds + neg_float = -0.0000152587f; + ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); +} + +TEST(BF8, ConvertFP16Nearest) +{ + // fix the tolerance value + float abs_tol = 1e-3; + // convert 0 fp16 to bf8 and back, check if holds + ASSERT_NEAR(half_t{0.0}, type_convert(type_convert(half_t{0.0})), abs_tol); + // convert minimal fp16 to bf8 and back, check if holds + ASSERT_NEAR(ck::NumericLimits::Min(), + type_convert(type_convert(ck::NumericLimits::Min())), + abs_tol); + // convert maximal bf8_t to fp16 and check if equal to 57344.0 + ASSERT_NEAR( + half_t{57344.0}, type_convert(type_convert(half_t{57344.0})), abs_tol); + // convert maximal fp16 to bf8 and back, check if clipped to 57344.0 + ASSERT_NEAR(half_t{57344.0}, + type_convert(type_convert(ck::NumericLimits::Max())), + abs_tol); + // convert QuietNaN fp16 to bf8_t and check if it is QuietNaN + ASSERT_NEAR(type_convert(0x80), + type_convert(ck::NumericLimits::QuietNaN()), + abs_tol); + // positive norm fp16 value to bf8 and back, check if holds + half_t pos_half = half_t{0.0000762939}; + ASSERT_NEAR(pos_half, type_convert(type_convert(pos_half)), abs_tol); + // negative norm fp16 value to bf8 and back, check if holds + half_t neg_half = half_t{-0.0000610351}; + ASSERT_NEAR(neg_half, type_convert(type_convert(neg_half)), abs_tol); + // positive subnorm fp16 value to bf8 and back, check if holds + pos_half = half_t{0.0000305175}; + ASSERT_NEAR(pos_half, type_convert(type_convert(pos_half)), abs_tol); + // negative subnorm fp16 value to bf8 and back, check if holds + neg_half = half_t{-0.0000152587}; + ASSERT_NEAR(neg_half, type_convert(type_convert(neg_half)), abs_tol); +} + +TEST(BF8, ConvertFP16Stochastic) +{ + // fix the tolerance value + float abs_tol = 1e-3; + // convert 0 fp16 to bf8 and back, check if holds + ASSERT_NEAR(half_t{0.0}, type_convert(f8_convert_sr(half_t{0.0})), abs_tol); + // convert minimal fp16 to bf8 and back, check if holds + ASSERT_NEAR(ck::NumericLimits::Min(), + type_convert(f8_convert_sr(ck::NumericLimits::Min())), + abs_tol); + // convert maximal bf8_t to fp16 and check if equal to 57344.0 + ASSERT_NEAR( + half_t{57344.0}, type_convert(f8_convert_sr(half_t{57344.0})), abs_tol); + // convert maximal fp16 to bf8 and back, check if clipped to 57344.0 + ASSERT_NEAR(half_t{57344.0}, + type_convert(f8_convert_sr(ck::NumericLimits::Max())), + abs_tol); + // convert QuietNaN fp16 to bf8_t and check if it is QuietNaN + ASSERT_NEAR(type_convert(0x80), + f8_convert_sr(ck::NumericLimits::QuietNaN()), + abs_tol); + // positive norm fp16 value to bf8 and back, check if holds + half_t pos_half = half_t{0.0000762939}; + ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); + // negative norm fp16 value to bf8 and back, check if holds + half_t neg_half = half_t{-0.0000610351}; + ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); + // positive subnorm fp16 value to bf8 and back, check if holds + pos_half = half_t{0.0000305175}; + ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); + // negative subnorm fp16 value to bf8 and back, check if holds + neg_half = half_t{-0.0000152587}; + ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); +} diff --git a/test/data_type/fp8.cpp b/test/data_type/f8.cpp similarity index 58% rename from test/data_type/fp8.cpp rename to test/data_type/f8.cpp index 5004fe952..0612a1cf4 100644 --- a/test/data_type/fp8.cpp +++ b/test/data_type/f8.cpp @@ -12,10 +12,11 @@ using ck::type_convert; TEST(FP8, NumericLimits) { - EXPECT_EQ(ck::NumericLimits::Min(), 0x08); - EXPECT_EQ(ck::NumericLimits::Max(), 0x77); - EXPECT_EQ(ck::NumericLimits::Lowest(), 0xF7); - EXPECT_EQ(ck::NumericLimits::QuietNaN(), 0x80); + // constants given for negative zero nan mode + EXPECT_EQ(ck::NumericLimits::Min(), type_convert(0x08)); + EXPECT_EQ(ck::NumericLimits::Max(), type_convert(0x7F)); + EXPECT_EQ(ck::NumericLimits::Lowest(), type_convert(0xFF)); + EXPECT_EQ(ck::NumericLimits::QuietNaN(), type_convert(0x80)); } TEST(FP8, ConvertFP32Nearest) @@ -35,12 +36,20 @@ TEST(FP8, ConvertFP32Nearest) type_convert(type_convert(std::numeric_limits::max())), abs_tol); // convert inf float to f8_t and check if it is qNan - ASSERT_NEAR(0x80, type_convert(std::numeric_limits::infinity()), abs_tol); - // positive float value to fp8 and back, check if holds - float pos_float = 0.0078125f; + ASSERT_NEAR(type_convert(0x80), + type_convert(std::numeric_limits::infinity()), + abs_tol); + // positive norm float value to fp8 and back, check if holds + float pos_float = 0.017578125f; + ASSERT_NEAR(pos_float, type_convert(type_convert(pos_float)), abs_tol); + // negative norm float value to fp8 and back, check if holds + float neg_float = -0.015625f; + ASSERT_NEAR(neg_float, type_convert(type_convert(neg_float)), abs_tol); + // positive subnorm float value to fp8 and back, check if holds + pos_float = 0.00390625f; ASSERT_NEAR(pos_float, type_convert(type_convert(pos_float)), abs_tol); - // negative float value to fp8 and back, check if holds - float neg_float = -0.0156250f; + // negative subnorm float value to fp8 and back, check if holds + neg_float = -0.001953125f; ASSERT_NEAR(neg_float, type_convert(type_convert(neg_float)), abs_tol); } @@ -61,12 +70,20 @@ TEST(FP8, ConvertFP32Stochastic) type_convert(f8_convert_sr(std::numeric_limits::max())), abs_tol); // convert inf float to f8_t and check if it is qNan - ASSERT_NEAR(0x80, f8_convert_sr(std::numeric_limits::infinity()), abs_tol); - // positive float value to fp8 and back, check if holds - float pos_float = 0.0078125f; + ASSERT_NEAR(type_convert(0x80), + f8_convert_sr(std::numeric_limits::infinity()), + abs_tol); + // positive norm float value to fp8 and back, check if holds + float pos_float = 0.017578125f; + ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); + // negative norm float value to fp8 and back, check if holds + float neg_float = -0.015625f; + ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); + // positive subnorm float value to fp8 and back, check if holds + pos_float = 0.00390625f; ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); - // negative float value to fp8 and back, check if holds - float neg_float = -0.0156250f; + // negative subnorm float value to fp8 and back, check if holds + neg_float = -0.001953125f; ASSERT_NEAR(neg_float, type_convert(f8_convert_sr(neg_float)), abs_tol); } @@ -87,12 +104,20 @@ TEST(FP8, ConvertFP16Nearest) type_convert(type_convert(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to f8_t and check if it is QuietNaN - ASSERT_NEAR(0x80, type_convert(ck::NumericLimits::QuietNaN()), abs_tol); - // positive fp16 value to fp8 and back, check if holds - half_t pos_half = half_t{0.0078125}; + ASSERT_NEAR(type_convert(0x80), + type_convert(ck::NumericLimits::QuietNaN()), + abs_tol); + // positive norm fp16 value to fp8 and back, check if holds + half_t pos_half = half_t{0.017578125}; + ASSERT_NEAR(pos_half, type_convert(type_convert(pos_half)), abs_tol); + // negative norm fp16 value to fp8 and back, check if holds + half_t neg_half = half_t{-0.015625}; + ASSERT_NEAR(neg_half, type_convert(type_convert(neg_half)), abs_tol); + // positive subnorm fp16 value to fp8 and back, check if holds + pos_half = half_t{0.00390625}; ASSERT_NEAR(pos_half, type_convert(type_convert(pos_half)), abs_tol); - // negative fp16 value to fp8 and back, check if holds - half_t neg_half = half_t{-0.0156250}; + // negative subnorm fp16 value to fp8 and back, check if holds + neg_half = half_t{-0.001953125}; ASSERT_NEAR(neg_half, type_convert(type_convert(neg_half)), abs_tol); } @@ -113,11 +138,19 @@ TEST(FP8, ConvertFP16Stochastic) type_convert(f8_convert_sr(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to f8_t and check if it is QuietNaN - ASSERT_NEAR(0x80, f8_convert_sr(ck::NumericLimits::QuietNaN()), abs_tol); - // positive fp16 value to fp8 and back, check if holds - half_t pos_half = half_t{0.0078125}; + ASSERT_NEAR(type_convert(0x80), + f8_convert_sr(ck::NumericLimits::QuietNaN()), + abs_tol); + // positive norm fp16 value to fp8 and back, check if holds + half_t pos_half = half_t{0.017578125}; + ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); + // negative norm fp16 value to fp8 and back, check if holds + half_t neg_half = half_t{-0.015625}; + ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); + // positive subnorm fp16 value to fp8 and back, check if holds + pos_half = half_t{0.00390625}; ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); - // negative fp16 value to fp8 and back, check if holds - half_t neg_half = half_t{-0.0156250}; + // negative subnorm fp16 value to fp8 and back, check if holds + neg_half = half_t{-0.001953125}; ASSERT_NEAR(neg_half, type_convert(f8_convert_sr(neg_half)), abs_tol); } -- GitLab From 74d32f0719eeaa2ebdc465bc6c1ed525f785c30c Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 12 Sep 2023 18:14:33 -0700 Subject: [PATCH 013/112] Add a switch to build DL kernels and build them with staging compiler. (#907) * enable building DL kernels with the daily staging compiler * move the DL_KERNELS flag to another function --- Jenkinsfile | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index 87d757e5c..c1a28a74e 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -210,6 +210,9 @@ def cmake_build(Map conf=[:]){ } else{ setup_args = ' -DBUILD_DEV=On' + setup_args } + if (params.DL_KERNELS){ + setup_args = setup_args + " -DDL_KERNELS=ON " + } if(build_type_debug){ setup_args = " -DCMAKE_BUILD_TYPE=debug -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'" + setup_args @@ -612,7 +615,7 @@ def process_results(Map conf=[:]){ //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION=rc1 0 21 * * * % ROCMVERSION=5.6;COMPILER_VERSION=;COMPILER_COMMIT= - 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : "" + 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : "" pipeline { agent none @@ -647,6 +650,10 @@ pipeline { name: "RUN_FULL_QA", defaultValue: false, description: "Select whether to run small set of performance tests (default) or full QA") + booleanParam( + name: "DL_KERNELS", + defaultValue: false, + description: "Select whether to build DL kernels (default: OFF)") } environment{ dbuser = "${dbuser}" -- GitLab From a66d14edf25d34cf2c4b021f8c27f12e5267d566 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Tue, 12 Sep 2023 22:17:56 -0500 Subject: [PATCH 014/112] fixed fp8 issues (#894) * fixed fp8 init; and reference gemm * Update host_tensor_generator.hpp * fixed convert * fixed reference gemm * fixed comments * fixed comments * fixed ci * fixed computeType --------- Co-authored-by: Jing Zhang --- .../groupnorm_sigmoid_mul_fp16.cpp | 16 ++++++----- .../device_gemm_multiple_d_xdl_cshuffle.hpp | 7 +++-- .../element/unary_element_wise_operation.hpp | 27 +++++++++++++++---- .../threadwise_tensor_slice_transfer.hpp | 9 +++---- .../cpu/reference_gemm.hpp | 7 ++--- .../profile_gemm_multiply_add_impl.hpp | 4 +-- 6 files changed, 45 insertions(+), 25 deletions(-) diff --git a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp b/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp index cc107b63d..b36bd761b 100644 --- a/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp +++ b/example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp @@ -14,18 +14,22 @@ using ComputeDataType = float; struct YElementOp { - template - __host__ __device__ void operator()(T& y, const T& x) const + template + __host__ __device__ void operator()(Y& y, const X& x) const { - static_assert(ck::is_same::value || ck::is_same::value || - ck::is_same::value, + static_assert(ck::is_same::value || ck::is_same::value || + ck::is_same::value, "Data type is not supported by this operation!"); - T a; + static_assert(ck::is_same::value || ck::is_same::value || + ck::is_same::value, + "Data type is not supported by this operation!"); + + X a; ck::tensor_operation::element_wise::Sigmoid{}(a, x); - y = x * a; + y = ck::type_convert(x * a); }; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index c90c28f5a..d98725cf9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -144,7 +144,8 @@ template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename ComputeDataType = EDataType> struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); - using ComputeDataType = EDataType; - // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< - ADataType, // TODO: distinguish A/B datatype + ADataType, BDataType, ComputeDataType, AccDataType, diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 34ac08b66..69a6540a0 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -27,6 +27,12 @@ struct PassThrough y = x; } + template <> + __host__ __device__ void operator()(float& y, const double& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(float& y, const float& x) const { @@ -81,6 +87,12 @@ struct PassThrough y = type_convert(x); } + template <> + __host__ __device__ void operator()(int8_t& y, const float& x) const + { + y = type_convert(x); + } + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> __host__ __device__ void operator()(int4_t& y, const int4_t& x) const @@ -416,14 +428,19 @@ struct Swish { Swish(float beta = 1.0f) : beta_(beta) {} - template - __host__ __device__ void operator()(T& y, const T& x) const + template + __host__ __device__ void operator()(Y& y, const X& x) const { - static_assert(is_same::value || is_same::value || - is_same::value, + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + static_assert(is_same::value || is_same::value || + is_same::value, "Data type is not supported by this operation!"); - y = x / (ck::type_convert(1) + ck::math::exp(-beta_ * x)); + float bx = -beta_ * type_convert(x); + y = type_convert(x / (1.f + ck::math::exp(bx))); }; float beta_ = 1.0f; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 605f2569c..277421407 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3 constexpr index_t src_offset = src_desc.CalculateOffset( src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_buf[Number{}]); - // apply type convert - dst_vector.template AsType()(i) = type_convert(v); + dst_vector.template AsType()(i) = v; }); const bool is_dst_valid = @@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic constexpr index_t dst_offset = dst_desc.CalculateOffset( dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_buf[Number{}]); // apply type convert - dst_buf(Number{}) = type_convert(v); + dst_buf(Number{}) = v; }); }); } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 309b4afad..95bd1e13d 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -20,7 +20,8 @@ template + typename CElementwiseOperation, + typename ComputType = ADataType> struct ReferenceGemm : public device::BaseOperator { // Argument @@ -64,8 +65,8 @@ struct ReferenceGemm : public device::BaseOperator for(int k = 0; k < K; ++k) { - ADataType v_a; - BDataType v_b; + ComputType v_a; + ComputType v_b; // use PassThrough instead of ConvertBF16RTN for reference calculation if constexpr(is_same_v{-1, 1}); break; default: - a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 0.2}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}); d0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } -- GitLab From 475188ca2ef2bfc931c2dc814c585e6248b9b463 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 13 Sep 2023 17:14:31 +0200 Subject: [PATCH 015/112] Add grouped conv bwd weight dl instances and new layout (#897) * Add grouped conv bwd weight dl instances and new layout * Add M and N padding * Remove todo comment * Enable grouped conv fwd dl k,c=1 generic instance * Comment fixes --- .../grouped_conv_bwd_weight_dl_fp16.cpp | 89 ++- .../run_grouped_conv_bwd_weight_example.inc | 14 +- ... => device_grouped_conv_bwd_weight_dl.hpp} | 693 ++++++++++-------- include/ck/utility/inner_product.hpp | 12 + ...ce_grouped_conv_bwd_weight_dl_instance.hpp | 86 +++ .../device_grouped_conv_fwd_dl_instance.hpp | 4 +- .../grouped_convolution_backward_weight.hpp | 304 ++++++++ .../grouped_conv1d_bwd_weight/CMakeLists.txt | 17 +- ...weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp | 45 ++ ..._weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp | 45 ++ ..._weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp | 45 ++ ...weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp | 45 ++ ..._weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp | 45 ++ ..._weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp | 45 ++ ...eight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp | 7 - ...weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp | 7 - ...weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp | 7 - .../grouped_conv2d_bwd_weight/CMakeLists.txt | 16 +- ...ght_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp | 46 ++ ...ight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 46 ++ ...ight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp | 46 ++ ...ght_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 46 ++ ...ight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 46 ++ ...ight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 46 ++ ...ht_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp | 1 - ...ght_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 1 - ...ght_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp | 1 - ...ht_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 1 - ...ght_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 1 - ...ght_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 1 - .../grouped_conv3d_bwd_weight/CMakeLists.txt | 17 +- ..._dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp | 45 ++ ...t_dl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp | 45 ++ ...t_dl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp | 45 ++ ..._dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 46 ++ ...t_dl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 46 ++ ...t_dl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 46 ++ ...xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp | 7 - ..._xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp | 7 - ..._xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp | 7 - ...xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 1 - ..._xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 1 - ..._xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 1 - .../test_grouped_convnd_bwd_weight.cpp | 70 +- 44 files changed, 1743 insertions(+), 449 deletions(-) rename include/ck/tensor_operation/gpu/device/impl/{device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp => device_grouped_conv_bwd_weight_dl.hpp} (63%) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp diff --git a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp index 375c309e1..cb6d68e04 100644 --- a/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp +++ b/example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_dl_fp16.cpp @@ -3,7 +3,7 @@ #include "common.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp" using InDataType = F16; using WeiDataType = F16; @@ -15,44 +15,55 @@ using WeiElementOp = PassThrough; using OutElementOp = PassThrough; template -using DeviceConvBwdWeightInstance = - ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl< - NDimSpatial, // NDimSpatial - InDataType, // InDataType - WeiDataType, // WeiDataType - OutDataType, // OutDataType - AccDataType, // AccDataType - InElementOp, // InElementwiseOperation - WeiElementOp, // WeiElementwiseOperation - OutElementOp, // OutElementwiseOperation - ConvBwdWeightDefault, // ConvBackwardWeightSpecialization - 256, // BlockSize - 128, // MPerBlock - 128, // NPerBlock - 16, // K0PerBlock - 2, // K1 - 4, // M1PerThread - 4, // N1PerThread - 1, // KPerThread - S<8, 2>, // M1N1ThreadClusterM1Xs - S<8, 2>, // M1N1ThreadClusterN1Xs - S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 - S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 - S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder - S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder - S<1, 1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 - S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder - S<1, 1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 - S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 - S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 - S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder - S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder - S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 - S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder - S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 - S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder - 5, // CThreadTransferSrcDstVectorDim - 4>; // CThreadTransferDstScalarPerVector +using DeviceConvBwdWeightInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Dl< + NDimSpatial, // NDimSpatial + ck::tuple_element_t>, // InLayout + ck::tuple_element_t>, // WeiLayout + ck::tuple_element_t>, // OutLayout + InDataType, // InDataType + WeiDataType, // WeiDataType + OutDataType, // OutDataType + AccDataType, // AccDataType + InElementOp, // InElementwiseOperation + WeiElementOp, // WeiElementwiseOperation + OutElementOp, // OutElementwiseOperation + ConvBwdWeightDefault, // ConvBackwardWeightSpecialization + 256, // BlockSize + 128, // MPerBlock + 128, // NPerBlock + 16, // K0PerBlock + 2, // K1 + 4, // M1PerThread + 4, // N1PerThread + 1, // KPerThread + S<8, 2>, // M1N1ThreadClusterM1Xs + S<8, 2>, // M1N1ThreadClusterN1Xs + S<1, 8, 1, 1, 2>, // ABlockTransferThreadSliceLengths_K0_M0_M1_K1 + S<1, 2, 1, 128, 1>, // ABlockTransferThreadClusterLengths_K0_M0_M1_K1 + S<0, 2, 3, 1, 4>, // ABlockTransferThreadClusterArrangeOrder + S<0, 2, 3, 1, 4>, // ABlockTransferSrcAccessOrder + S<1, 1, 1, 1, 1>, // ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 + S<0, 2, 3, 1, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 1, 1>, // ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 + S<1, 1, 1, 8, 2>, // BBlockTransferThreadSliceLengths_K0_N0_N1_K1 + S<1, 16, 1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N0_N1_K1 + S<0, 1, 4, 2, 3>, // BBlockTransferThreadClusterArrangeOrder + S<0, 1, 4, 2, 3>, // BBlockTransferSrcAccessOrder + S<1, 1, 1, 8, 1>, // BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 4, 2, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + S<1, 1, 1, 1, 2>, // BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 + S<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + 4>; // CThreadTransferDstScalarPerVector #include "run_grouped_conv_bwd_weight_example.inc" diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index 29ce0324a..34bca7dc8 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -14,20 +14,8 @@ template bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, const ck::utils::conv::ConvParam& conv_param) { - ck::index_t split_k; - // Set split_k = 2 for xdl op, split_k = 1 for dl // Dl op doesn't support split_k > 1 - // TODO: Add Dl op split_k > 1 support - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102")) - { - split_k = 2; - } - else - { - split_k = 1; - } + constexpr ck::index_t split_k = 1; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed< diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp similarity index 63% rename from include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp rename to include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 198751cdf..014140c46 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -72,6 +73,9 @@ __global__ void const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \ + defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); @@ -96,9 +100,27 @@ __global__ void block_2_ctile_map, integral_constant{}, integral_constant{}); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = batch_count; + ignore = a_grid_desc_kbatch_k0_m0_m1_k1; + ignore = b_grid_desc_kbatch_k0_n0_n1_k1; + ignore = c_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = block_2_ctile_map; + ignore = compute_ptr_offset_of_batch; + + compute_ptr_offset_of_batch.GetAPtrOffset(0); + compute_ptr_offset_of_batch.GetBPtrOffset(0); + compute_ptr_offset_of_batch.GetCPtrOffset(0); +#endif } template -struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl - : public DeviceGroupedConvBwdWeight< - NDimSpatial, - ck::tuple_element_t>, - ck::tuple_element_t>, - ck::tuple_element_t>, - InDataType, - WeiDataType, - OutDataType, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> +struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight { - using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl; + // 1d + static constexpr bool is_NWGK_GKXC_NWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNWK_GKXC_GNWC = + is_same_v && + is_same_v && + is_same_v; + // 2d + static constexpr bool is_NHWGK_GKYXC_NHWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNHWK_GKYXC_GNHWC = + is_same_v && + is_same_v && + is_same_v; + // 3d + static constexpr bool is_NDHWGK_GKZYXC_NDHWGC = + is_same_v && + is_same_v && + is_same_v; + static constexpr bool is_GNDHWK_GKZYXC_GNDHWC = + is_same_v && + is_same_v && + is_same_v; + + using DeviceOp = DeviceGroupedConvBwdWeight_Dl; using ADataType = OutDataType; using BDataType = InDataType; @@ -176,6 +215,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl static constexpr auto I4 = Number<4>{}; static constexpr auto I5 = Number<5>{}; + static constexpr auto spatial_offset = I3; + static constexpr auto K1Number = Number{}; static constexpr auto GemmK1Number = K1Number; @@ -195,12 +236,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -209,90 +250,102 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl { using namespace ck; - const index_t Wi = input_spatial_lengths[0]; - const index_t Wo = output_spatial_lengths[0]; - const index_t X = filter_spatial_lengths[0]; - const index_t InLeftPadW = input_left_pads[0]; - const index_t InRightPadW = input_right_pads[0]; - const index_t ConvStrideW = conv_filter_strides[0]; - const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset]; + const index_t InLeftPadW = input_left_pads[I0]; + const index_t InRightPadW = input_right_pads[I0]; + const index_t ConvStrideW = conv_filter_strides[I0]; + const index_t ConvDilationW = conv_filter_dilations[I0]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset]; const index_t GemmKTotal = N * Wo; - const index_t GemmM = K; - const index_t GemmN = C * X; - const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride)); - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C)); + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wi, C), make_tuple(InWStride, InCStride)); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weights tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } else { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); - const auto in_n_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride)); // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -321,38 +374,43 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } } // function end template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -361,103 +419,111 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl { using namespace ck; - const index_t Hi = input_spatial_lengths[0]; - const index_t Wi = input_spatial_lengths[1]; - - const index_t Ho = output_spatial_lengths[0]; - const index_t Wo = output_spatial_lengths[1]; - - const index_t Y = filter_spatial_lengths[0]; - const index_t X = filter_spatial_lengths[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Hi = a_g_n_c_wis_lengths[spatial_offset]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I1]; + const index_t Ho = e_g_n_k_wos_lengths[spatial_offset]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I1]; + const index_t Y = b_g_k_c_xs_lengths[spatial_offset]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset + I1]; + + const index_t InLeftPadH = input_left_pads[I0]; + const index_t InLeftPadW = input_left_pads[I1]; + const index_t InRightPadH = input_right_pads[I0]; + const index_t InRightPadW = input_right_pads[I1]; + const index_t ConvStrideH = conv_filter_strides[I0]; + const index_t ConvStrideW = conv_filter_strides[I1]; + const index_t ConvDilationH = conv_filter_dilations[I0]; + const index_t ConvDilationW = conv_filter_dilations[I1]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InHStride = a_g_n_c_wis_strides[spatial_offset]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I1]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1]; const index_t GemmKTotal = N * Ho * Wo; - const index_t GemmM = K; - const index_t GemmN = C * X * Y; - const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C)); + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride)); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } else { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - const auto in_n_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(InNStride, InHStride, InWStride, InCStride)); // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -488,39 +554,44 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } } // function end template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -529,110 +600,120 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl { using namespace ck; - const index_t Di = input_spatial_lengths[0]; - const index_t Hi = input_spatial_lengths[1]; - const index_t Wi = input_spatial_lengths[2]; - - const index_t Do = output_spatial_lengths[0]; - const index_t Ho = output_spatial_lengths[1]; - const index_t Wo = output_spatial_lengths[2]; - - const index_t Z = filter_spatial_lengths[0]; - const index_t Y = filter_spatial_lengths[1]; - const index_t X = filter_spatial_lengths[2]; - - const index_t InLeftPadD = input_left_pads[0]; - const index_t InLeftPadH = input_left_pads[1]; - const index_t InLeftPadW = input_left_pads[2]; - - const index_t InRightPadD = input_right_pads[0]; - const index_t InRightPadH = input_right_pads[1]; - const index_t InRightPadW = input_right_pads[2]; - - const index_t ConvStrideD = conv_filter_strides[0]; - const index_t ConvStrideH = conv_filter_strides[1]; - const index_t ConvStrideW = conv_filter_strides[2]; - - const index_t ConvDilationD = conv_filter_dilations[0]; - const index_t ConvDilationH = conv_filter_dilations[1]; - const index_t ConvDilationW = conv_filter_dilations[2]; + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Di = a_g_n_c_wis_lengths[spatial_offset + I0]; + const index_t Hi = a_g_n_c_wis_lengths[spatial_offset + I1]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I2]; + const index_t Do = e_g_n_k_wos_lengths[spatial_offset + I0]; + const index_t Ho = e_g_n_k_wos_lengths[spatial_offset + I1]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I2]; + const index_t Z = b_g_k_c_xs_lengths[spatial_offset + I0]; + const index_t Y = b_g_k_c_xs_lengths[spatial_offset + I1]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset + I2]; + + const index_t InLeftPadD = input_left_pads[I0]; + const index_t InLeftPadH = input_left_pads[I1]; + const index_t InLeftPadW = input_left_pads[I2]; + const index_t InRightPadD = input_right_pads[I0]; + const index_t InRightPadH = input_right_pads[I1]; + const index_t InRightPadW = input_right_pads[I2]; + const index_t ConvStrideD = conv_filter_strides[I0]; + const index_t ConvStrideH = conv_filter_strides[I1]; + const index_t ConvStrideW = conv_filter_strides[I2]; + const index_t ConvDilationD = conv_filter_dilations[I0]; + const index_t ConvDilationH = conv_filter_dilations[I1]; + const index_t ConvDilationW = conv_filter_dilations[I2]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InDStride = a_g_n_c_wis_strides[spatial_offset]; + const auto InHStride = a_g_n_c_wis_strides[spatial_offset + I1]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I2]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2]; const index_t GemmKTotal = N * Do * Ho * Wo; - const index_t GemmM = K; - const index_t GemmN = C * Z * X * Y; - const index_t GemmKBatch = batch_k; const index_t GemmK0 = math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * K0PerBlock; - const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { // A: output tensor - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // B: input tensor - const auto in_gemmktotal_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C)); + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride)); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } else { - const auto out_gemmktotal_gemmm_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); - const auto in_n_di_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride)); // A: output tensor - const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( - out_gemmktotal_gemmm_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmM)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_gemmkpad_gemmm_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmM)), + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); @@ -672,27 +753,32 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( - in_gemmktotal_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmkpad_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_gemmmpad_gemmnpad_grid_desc); } } // function end @@ -701,22 +787,22 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl static auto GetABCGridDesc() { return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( - 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); + {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); } template ::type = false> static auto GetABCGridDesc() { return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( - 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); + {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); } template ::type = false> static auto GetABCGridDesc() { - return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, - 1, - 1, + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>({1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, @@ -785,11 +871,11 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl WeiDataType* p_wei_grid, const OutDataType* p_out_grid, const std::array& a_g_n_c_wis_lengths, // input - const std::array& /*a_g_n_c_wis_strides*/, + const std::array& a_g_n_c_wis_strides, const std::array& b_g_k_c_xs_lengths, // weight - const std::array& /*b_g_k_c_xs_strides*/, + const std::array& b_g_k_c_xs_strides, const std::array& e_g_n_k_wos_lengths, // output - const std::array& /*e_g_n_k_wos_strides*/, + const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -809,38 +895,24 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl a_element_op_{out_element_op}, b_element_op_{wei_element_op}, c_element_op_{in_element_op}, - Conv_G_{a_g_n_c_wis_lengths[0]}, - Conv_N_{a_g_n_c_wis_lengths[1]}, - Conv_K_{b_g_k_c_xs_lengths[1]}, - Conv_C_{a_g_n_c_wis_lengths[2]}, - input_spatial_lengths_{}, - filter_spatial_lengths_{}, - output_spatial_lengths_{}, + Conv_G_{a_g_n_c_wis_lengths[I0]}, + Conv_K_{b_g_k_c_xs_lengths[I1]}, + Conv_C_{a_g_n_c_wis_lengths[I2]}, + filter_lengths_{b_g_k_c_xs_lengths}, conv_filter_strides_{conv_filter_strides}, conv_filter_dilations_{conv_filter_dilations}, input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads}, k_batch_{split_k} { - constexpr index_t spatial_offset = 3; - std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, - end(a_g_n_c_wis_lengths), - begin(input_spatial_lengths_)); - std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset, - end(b_g_k_c_xs_lengths), - begin(filter_spatial_lengths_)); - std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset, - end(e_g_n_k_wos_lengths), - begin(output_spatial_lengths_)); - const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - Conv_N_, - Conv_K_, - Conv_C_, - input_spatial_lengths_, - filter_spatial_lengths_, - output_spatial_lengths_, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -863,24 +935,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); // A/B/C Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = - Conv_N_ * Conv_K_ * - std::accumulate(begin(output_spatial_lengths_), - end(output_spatial_lengths_), - index_t{1}, - std::multiplies<>{}); - compute_ptr_offset_of_batch_.BatchStrideB_ = - Conv_N_ * Conv_C_ * - std::accumulate(begin(input_spatial_lengths_), - end(input_spatial_lengths_), - index_t{1}, - std::multiplies<>{}); - compute_ptr_offset_of_batch_.BatchStrideC_ = - Conv_K_ * Conv_C_ * - std::accumulate(begin(filter_spatial_lengths_), - end(filter_spatial_lengths_), - index_t{1}, - std::multiplies<>{}); + compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = b_g_k_c_xs_strides[I0]; } const ADataType* p_a_grid_; @@ -908,13 +965,10 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl // for checking IsSupportedArgument() const index_t Conv_G_; - const index_t Conv_N_; const index_t Conv_K_; const index_t Conv_C_; - std::array input_spatial_lengths_; - std::array filter_spatial_lengths_; - std::array output_spatial_lengths_; + std::array filter_lengths_; const std::array& conv_filter_strides_; const std::array& conv_filter_dilations_; const std::array& input_left_pads_; @@ -1036,10 +1090,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl static bool IsSupportedArgument(const Argument& arg) { - // check device - if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || - ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || - ck::get_device_name() == "gfx1102")) + + // DL version only supports split_k equal to 1 + if(arg.k_batch_ != 1) + return false; + + if constexpr(!((NDimSpatial == 1 && (is_NWGK_GKXC_NWGC || is_GNWK_GKXC_GNWC)) || + (NDimSpatial == 2 && (is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) || + (NDimSpatial == 3 && (is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC)))) { return false; } @@ -1050,8 +1108,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl // check if it's 1x1, stride=1 pad = 0 conv for(int i = 0; i < NDimSpatial; i++) { - if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && - arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + if(!(arg.filter_lengths_[spatial_offset + i] == 1 && + arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 && + arg.input_right_pads_[i] == 0)) { return false; } @@ -1206,7 +1265,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl auto str = std::stringstream(); // clang-format off - str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl" + str << "DeviceGroupedConvBwdWeight_Dl" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/include/ck/utility/inner_product.hpp b/include/ck/utility/inner_product.hpp index b58b2b331..16c139c00 100644 --- a/include/ck/utility/inner_product.hpp +++ b/include/ck/utility/inner_product.hpp @@ -72,6 +72,18 @@ inner_product(const float4_t& a, const float4_t& b, f c); } +template <> +__device__ void inner_product(const bhalf_t& a, const bhalf_t& b, float& c) +{ + inner_product(type_convert(a), type_convert(b), c); +} + +template <> +__device__ void inner_product(const half_t& a, const half_t& b, float& c) +{ + inner_product(type_convert(a), type_convert(b), c); +} + template <> __device__ void inner_product(const half2_t& a, const half2_t& b, float& c) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp new file mode 100644 index 000000000..29fe298e4 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using namespace ck::tensor_layout::convolution; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdWeightDefault = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default; + +static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 = + ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_weight_dl_f32_instances = std::tuple< + // clang-format off + //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector| + //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | | + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_dl_f16_instances = std::tuple< + // clang-format off + //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector| + //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | | + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_dl_bf16_instances = std::tuple< + // clang-format off + //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector| + //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | | + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp index 86ff43e5c..1ebd0c3d6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp @@ -55,8 +55,8 @@ using device_grouped_conv2d_fwd_dl_f16_instances = std::tuple< // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instances // TODO: Change to ScalarPerVector = 1 when inner_product will be supported - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 8, 16, 4, 2, 2, 1, 2, 1, S<4, 2>, S<1, 1>, S<2, 1, 2, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 1, 2>, S<2, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 2>, + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 8, 16, 4, 2, 1, 1, 2, 1, S<4, 2>, S<1, 1>, S<2, 1, 2, 1>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>, DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> // clang-format on diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index cad3e1ace..c7a694b4c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -17,6 +17,7 @@ namespace tensor_operation { namespace device { namespace instance { +// xdl // conv1d backward weight void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( std::vector>>& instances); +#ifdef DL_KERNELS +// dl +// conv1d backward weight +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances( + std::vector>>& instances); + +// conv2d backward weight +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); + +// conv3d backward weight +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif + template && is_same_v && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); +#endif add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); +#endif add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + op_ptrs); +#endif add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( op_ptrs); } } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef DL_KERNELS + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } } else if constexpr(NumDimSpatial == 2) { @@ -266,12 +522,20 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( + op_ptrs); +#endif add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( + op_ptrs); +#endif add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( op_ptrs); } @@ -279,6 +543,10 @@ struct DeviceOperationInstanceFactory && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + op_ptrs); +#endif add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( op_ptrs); } @@ -289,12 +557,20 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( + op_ptrs); +#endif add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); +#endif add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( op_ptrs); } @@ -302,6 +578,10 @@ struct DeviceOperationInstanceFactory && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + op_ptrs); +#endif add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( op_ptrs); } @@ -315,12 +595,20 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( + op_ptrs); +#endif add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( + op_ptrs); +#endif add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( op_ptrs); } @@ -328,6 +616,10 @@ struct DeviceOperationInstanceFactory && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + op_ptrs); +#endif add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( op_ptrs); } @@ -338,12 +630,20 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); +#endif add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); +#endif add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( op_ptrs); } @@ -351,6 +651,10 @@ struct DeviceOperationInstanceFactory && is_same_v) { +#ifdef DL_KERNELS + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + op_ptrs); +#endif add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( op_ptrs); } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt index 3808e0248..660f2544a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt @@ -1,5 +1,16 @@ -add_instance_library(device_grouped_conv1d_bwd_weight_instance +set(GROUPED_CONV1D_BWD_WEIGHT device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp - device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp -) + device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp) + +if(DL_KERNELS) + list(APPEND GROUPED_CONV1D_BWD_WEIGHT + device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp + device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp + device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp + device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp + device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp + device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp) +endif() + +add_instance_library(device_grouped_conv1d_bwd_weight_instance ${GROUPED_CONV1D_BWD_WEIGHT}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp new file mode 100644 index 000000000..d7a15784a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_bf16_instances<1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_bf16_instances<1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp new file mode 100644 index 000000000..6f891ef94 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f16_instances<1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f16_instances<1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp new file mode 100644 index 000000000..2db7fbdb9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f32_instances<1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f32_instances<1, + GNWC, + GKXC, + GNWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp new file mode 100644 index 000000000..a92cb4285 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_bf16_instances<1, + NWGC, + GKXC, + NWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_bf16_instances<1, + NWGC, + GKXC, + NWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp new file mode 100644 index 000000000..713d70f37 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f16_instances<1, + NWGC, + GKXC, + NWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f16_instances<1, + NWGC, + GKXC, + NWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp new file mode 100644 index 000000000..05b29af74 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f32_instances<1, + NWGC, + GKXC, + NWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f32_instances<1, + NWGC, + GKXC, + NWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp index 25ea4f48c..f9368ab57 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp @@ -1,13 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp index 8065012f1..edba11efb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp @@ -1,13 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp index c70a54c2d..5f28994f0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp @@ -1,13 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index b7b9fc92d..578e9029e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -1,9 +1,19 @@ -add_instance_library(device_grouped_conv2d_bwd_weight_instance +set(GROUPED_CONV2D_BWD_WEIGHT device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp - device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp -) + device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp) +if(DL_KERNELS) + list(APPEND GROUPED_CONV2D_BWD_WEIGHT + device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp + device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp + device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp + device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp) +endif() + +add_instance_library(device_grouped_conv2d_bwd_weight_instance ${GROUPED_CONV2D_BWD_WEIGHT}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp new file mode 100644 index 000000000..37b465e6c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_bf16_instances<2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_bf16_instances<2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp new file mode 100644 index 000000000..64b38655e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f16_instances<2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f16_instances<2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp new file mode 100644 index 000000000..2cc6527df --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f32_instances<2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f32_instances<2, + GNHWC, + GKYXC, + GNHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 000000000..cf3db8331 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_bf16_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_bf16_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 000000000..eb592b837 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f16_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f16_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp new file mode 100644 index 000000000..ed7ab6f4a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f32_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f32_instances<2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp index cf39c8601..17f5ee4e2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -2,7 +2,6 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index d52f0b4d1..abb90b7c7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -2,7 +2,6 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index 62547a501..2e0fef9cf 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -2,7 +2,6 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp index 1cb9991a8..614cc0a7e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -2,7 +2,6 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp index e64d55c3b..f6e1ada35 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -2,7 +2,6 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp index 5fa4c9ba3..384706414 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -2,7 +2,6 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 5118599b4..6bd9d6e64 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -1,8 +1,19 @@ -add_instance_library(device_grouped_conv3d_bwd_weight_instance +set(GROUPED_CONV3D_BWD_WEIGHT device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp - device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp -) + device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) + +if(DL_KERNELS) + list(APPEND GROUPED_CONV3D_BWD_WEIGHT + device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp + device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp + device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp + device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp) +endif() + +add_instance_library(device_grouped_conv3d_bwd_weight_instance ${GROUPED_CONV3D_BWD_WEIGHT}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp new file mode 100644 index 000000000..c9646d085 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_bf16_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_bf16_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp new file mode 100644 index 000000000..4eb54d0f9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f16_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f16_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp new file mode 100644 index 000000000..c22825cce --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f32_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f32_instances<3, + GNDHWC, + GKZYXC, + GNDHWK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 000000000..a37e6cbf3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_bf16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_bf16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 000000000..b548c6fd5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f16_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 000000000..10d708a3f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_dl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f32_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_dl_f32_instances<3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp index c8f456db8..91d80e4f7 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp @@ -1,13 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp index 099123ecb..d03f0a7ba 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -1,13 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp index 0eda980b3..7c24cc8fd 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp @@ -1,13 +1,6 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp index 1e5c1946f..a394e0d6f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -2,7 +2,6 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp index e0b442bf2..eba721c7b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -2,7 +2,6 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp index 4bb794824..7dd289139 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -2,7 +2,6 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - #include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp" namespace ck { diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index ae15f6c0c..bf95727fb 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -14,6 +14,8 @@ #include "profiler/profile_grouped_conv_bwd_weight_impl.hpp" +using namespace ck::tensor_layout::convolution; + template class TestGroupedConvndBwdWeight : public ::testing::Test { @@ -27,28 +29,59 @@ class TestGroupedConvndBwdWeight : public ::testing::Test using NDimSpatial = std::tuple_element_t<6, Tuple>; std::vector conv_params; - ck::index_t split_k{2}; + std::vector split_ks{1, 2}; + + bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k) + { + // K or C are odd is supported only by DL kernel (only applies to fp16) + // DL kernel is only supported for split_k=1 + if constexpr(std::is_same_v) + { + if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0)) + { + return true; + } + } + + // 1d NWGC is only supported by DL kernel + // DL kernel is only supported for split_k=1 + if constexpr(std::is_same_v && std::is_same_v) + { + if(split_k != 1) + { + return true; + } + } + + return false; + } void Run() { EXPECT_FALSE(conv_params.empty()); bool pass = true; - for(auto& param : conv_params) + for(auto split_k : split_ks) { - pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl( - true, // do_verification - 1, // init_method: integer value - false, // do_log - false, // time_kernel - param, - split_k); + for(auto& param : conv_params) + { + if(!skip_case(param, split_k)) + { + pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param, + split_k); + } + } } EXPECT_TRUE(pass); } @@ -69,12 +102,13 @@ class TestGroupedConvndBwdWeight3d : public TestGroupedConvndBwdWeight { }; -using namespace ck::tensor_layout::convolution; - using KernelTypes1d = ::testing::Types< std::tuple>, std::tuple>, - std::tuple>>; + std::tuple>, + std::tuple>, + std::tuple>, + std::tuple>>; using KernelTypes2d = ::testing::Types< std::tuple>, std::tuple>, -- GitLab From 5fe687fa270aeb9e387278a52ea4f09d997c4b6b Mon Sep 17 00:00:00 2001 From: Jun Liu Date: Wed, 13 Sep 2023 14:38:12 -0700 Subject: [PATCH 016/112] [Cmake] Set cmake default build type Release and path to /opt/rocm (#914) --- CMakeLists.txt | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index a79af9cc3..e3b60347d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,18 @@ cmake_minimum_required(VERSION 3.14) +# This has to be initialized before the project() command appears +# Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE +if( NOT MSVC_IDE AND NOT CMAKE_BUILD_TYPE ) + set( CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build, options are: None Debug Release RelWithDebInfo MinSizeRel." ) +endif() + +# Default installation path +if(WIN32) + set(CMAKE_INSTALL_PREFIX "/opt/rocm/x86_64-w64-mingw32" CACHE PATH "") +else() + set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "") +endif() + set(version 1.1.0) # Check support for CUDA/HIP in Cmake project(composable_kernel VERSION ${version}) -- GitLab From 0d8efaa13d2cb93e6b1cd0ae702bf15778d1d4ec Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 14 Sep 2023 09:36:26 -0700 Subject: [PATCH 017/112] change the cmake update method (#918) --- Dockerfile | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/Dockerfile b/Dockerfile index e479268f4..b0954fdc7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -73,15 +73,8 @@ RUN wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releas RUN gunzip /usr/local/bin/ninja.gz RUN chmod a+x /usr/local/bin/ninja RUN git clone https://github.com/nico/ninjatracing.git -RUN apt purge --auto-remove -y cmake -RUN apt update -RUN apt install -y software-properties-common lsb-release -RUN apt clean all -RUN wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | tee /etc/apt/trusted.gpg.d/kitware.gpg >/dev/null -RUN apt-add-repository "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" -RUN apt install -y kitware-archive-keyring -RUN rm /etc/apt/trusted.gpg.d/kitware.gpg -RUN apt install -y cmake +# Update the cmake to the latest version +RUN pip install --upgrade cmake # Setup ubsan environment to printstacktrace RUN ln -s /usr/bin/llvm-symbolizer-3.8 /usr/local/bin/llvm-symbolizer -- GitLab From f9d0eddb90d80b2f33da7a60194fac38dc4a19ba Mon Sep 17 00:00:00 2001 From: zjing14 Date: Thu, 14 Sep 2023 21:04:10 -0500 Subject: [PATCH 018/112] Add fp16/fp8 support into Grouped gemm FixedNK (#874) * move all arguments into device * add b2c_tile_map * add examples * add SetDeviceKernelArgs * dedicated fixed_nk solution * init client api * add grouped_gemm_bias example * add a instance * add instances * formatting * fixed cmake * Update EnableCompilerWarnings.cmake * Update cmake-ck-dev.sh * clean; fixed comments * fixed comment * add instances for fp32 output * add instances for fp32 output * add fp32 out client example * fixed CI * init commit for kbatch * add splitk gridwise * format * fixed * clean deviceop * clean code * finish splitk * fixed instances * change m_loops to tile_loops * add setkbatch * clean code * add splitK+bias * add instances * opt mk_nk instances * clean examples * fixed CI * remove zero * finished non-zero * clean * clean code * optimized global_barrier * fixed ci * fixed CI * instance and client * removed AddBias * format * fixed CI * fixed CI * move 20_grouped_gemm to 21_grouped_gemm * clean * formatting * clean * clean * fixed computeType --------- Co-authored-by: Jing Zhang --- client_example/22_grouped_gemm/CMakeLists.txt | 8 + .../grouped_gemm_fixed_nk_fp16.cpp | 238 +++++++++++++ .../grouped_gemm_fixed_nk_fp8.cpp | 238 +++++++++++++ .../grouped_gemm_fixed_nk_i8.cpp | 238 +++++++++++++ example/15_grouped_gemm/CMakeLists.txt | 5 + .../grouped_gemm_xdl_fixed_nk_fp8.cpp | 330 ++++++++++++++++++ .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 3 + .../element/unary_element_wise_operation.hpp | 12 + ...se_gemm_multiple_d_xdl_splitk_cshuffle.hpp | 46 +-- .../gpu/grouped_gemm_fixed_nk.hpp | 190 ++++++++++ .../gpu/grouped_gemm_fixed_nk/CMakeLists.txt | 10 + ...fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp | 75 ++++ ...fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp | 78 +++++ ..._fixed_nk_f16_f8_f16_mk_kn_mn_instance.cpp | 75 ++++ ..._fixed_nk_f16_f8_f16_mk_nk_mn_instance.cpp | 78 +++++ ..._fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp | 75 ++++ ..._fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp | 78 +++++ 17 files changed, 1749 insertions(+), 28 deletions(-) create mode 100644 client_example/22_grouped_gemm/CMakeLists.txt create mode 100644 client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp create mode 100644 client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp create mode 100644 client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp create mode 100644 example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp diff --git a/client_example/22_grouped_gemm/CMakeLists.txt b/client_example/22_grouped_gemm/CMakeLists.txt new file mode 100644 index 000000000..05b9e1e29 --- /dev/null +++ b/client_example/22_grouped_gemm/CMakeLists.txt @@ -0,0 +1,8 @@ +add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp) +target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_operations) + +add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp) +target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_operations) + +add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp) +target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_operations) diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp new file mode 100644 index 000000000..3503ae8b2 --- /dev/null +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + // Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; + Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; + + int group_count = Ms.size(); + + for(int i = 0; i < group_count; ++i) + { + Ns.push_back(768); + Ks.push_back(4608); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + {}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 32); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp new file mode 100644 index 000000000..b288550b7 --- /dev/null +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F8; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; + + int group_count = Ms.size(); + + for(int i = 0; i < group_count; ++i) + { + Ns.push_back(768); + Ks.push_back(4608); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + {}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 16); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp new file mode 100644 index 000000000..c60daa3b3 --- /dev/null +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +using I8 = int8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = I8; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + Ms = {167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; + + int group_count = Ms.size(); + + for(int i = 0; i < group_count; ++i) + { + Ns.push_back(768); + Ks.push_back(4608); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + {}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 32); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 2f880af3c..45e02126b 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -25,6 +25,11 @@ if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8) endif() +if(DTYPES MATCHES "f8" OR NOT DEFINED DTYPES) + add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8) +endif() + if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp new file mode 100644 index 000000000..1c50dc051 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp @@ -0,0 +1,330 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +template +using S = ck::Sequence; + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F8; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_Fixed_NK + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + int k_batch = 1; + bool time_kernel = false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_Cs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + } + } + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + c_tensors_device[i]->SetZero(); + + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + 1, + problem_size.stride_Bs[i], + 1, + {}}); + + grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op); + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + problem_size.Ms = { + 167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148}; + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ns.push_back(768); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + exit(0); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index 8cea09ee5..56132f7a0 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -193,6 +193,7 @@ template struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK(x); } + template <> + __host__ __device__ void operator()(float& y, const half_t& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(int8_t& y, const int8_t& x) const { y = x; } + template <> + __host__ __device__ void operator()(half_t& y, const int8_t& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(int8_t& y, const int32_t& x) const { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp index e22391293..a3343bd3a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -29,7 +29,9 @@ namespace ck { // E = cde_op(C, D0, D1, ...) // Assume: // D0, D1, ... and E have the same layout -template ())>; - // denorm test fix, required to work around fp16 mfma issue - // we convert fp16->fp32->bf16 and execute bf16 mfma instruction - // when mfma if fixed, remove this section and update - // ABDataTypeAdjusted -> ABDataType throughout this file -#if CK_WORKAROUND_DENORM_FIX - using ABDataTypeAdjusted = - conditional_t, ck::bhalf_t, ABDataType>; -#else - using ABDataTypeAdjusted = ABDataType; -#endif - __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1() { // A matrix in LDS memory, dst of blockwise copy @@ -196,7 +187,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ABDataType), + sizeof(ComputeType), c_block_size * sizeof(CShuffleDataType)); } @@ -401,8 +392,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // check tensor size: cannot be larger than 2GB each constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && - b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) { return false; @@ -470,8 +461,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, typename CDEElementwiseOperation_, typename Block2ETileMap> - __device__ static void Run(const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer p_ds_grid, EDataType* __restrict__ p_e_grid, void* __restrict__ p_shared, @@ -538,8 +529,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle Sequence<1, AK0PerBlock, MPerBlock, AK1>, ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, - ABDataType, - ABDataTypeAdjusted, + ADataType, + ComputeType, decltype(a_grid_desc_kbatch_ak0_m_ak1), decltype(a_block_desc_kbatch_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -569,8 +560,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle Sequence<1, BK0PerBlock, NPerBlock, BK1>, BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, - ABDataType, - ABDataTypeAdjusted, + BDataType, + ComputeType, decltype(b_grid_desc_kbatch_bk0_n_bk1), decltype(b_block_desc_kbatch_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -606,11 +597,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // sanity check constexpr index_t KPack = math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ABDataTypeAdjusted, + ComputeType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -683,11 +674,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), - a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); @@ -999,8 +989,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle const index_t KBatch, const Block2ETileMap& block_2_etile_map) { - const auto p_a_grid = reinterpret_cast(p_a_grid_); - const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); const auto p_e_grid = reinterpret_cast(p_e_grid_); using DsGridDesc_M_N = diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp new file mode 100644 index 000000000..e8c368cb3 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// fp16_output +void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances); + +// fp8_inputB +void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instances( + std::vector>>& instances); + +// i8_inputB +void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedGemmFixedNK> +{ + using DeviceOp = DeviceGroupedGemmFixedNK; + + static auto GetInstances() + { + std::vector> op_ptrs; + + // fp16_output + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + } + + // fp8_input + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instances(op_ptrs); + } + } + + // i8_input + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt new file mode 100644 index 000000000..45f8130c5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt @@ -0,0 +1,10 @@ +add_instance_library(device_grouped_gemm_fixed_nk_instance + device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp + + device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instance.cpp + + device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 000000000..584d2be17 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using DsDataType = ck::Tuple<>; + +using DsLayout = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_irregular_tile_instances = std::tuple< + // clang-format off + //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 000000000..9f0e47fe5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using DsDataType = ck::Tuple<>; + +using DsLayout = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F16, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instance.cpp new file mode 100644 index 000000000..885ca8c63 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using DsDataType = ck::Tuple<>; + +using DsLayout = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_irregular_tile_instances = std::tuple< + // clang-format off + //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instance.cpp new file mode 100644 index 000000000..9f1ee423e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using DsDataType = ck::Tuple<>; + +using DsLayout = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, F8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp new file mode 100644 index 000000000..fb1ba51fe --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = int8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using DsDataType = ck::Tuple<>; + +using DsLayout = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_irregular_tile_instances = std::tuple< + // clang-format off + //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp new file mode 100644 index 000000000..23b88c2a4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = int8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using DsDataType = ck::Tuple<>; + +using DsLayout = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, F16, I8, F32, F32, DsDataType, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck -- GitLab From bc2d0583d38228bd4bfc51a3d88d31f2e4388d5c Mon Sep 17 00:00:00 2001 From: Bartlomiej Kocot Date: Wed, 13 Sep 2023 15:34:25 +0000 Subject: [PATCH 019/112] Stylistic improvements for grouped convolution code Remove unnecessary ignoring Update test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp --- .../gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp | 4 ---- .../grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp | 1 - .../test_grouped_convnd_bwd_weight.cpp | 4 ++-- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 014140c46..ef57e6e4d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -110,10 +110,6 @@ __global__ void ignore = c_grid_desc_m0_m10_m11_n0_n10_n11; ignore = block_2_ctile_map; ignore = compute_ptr_offset_of_batch; - - compute_ptr_offset_of_batch.GetAPtrOffset(0); - compute_ptr_offset_of_batch.GetBPtrOffset(0); - compute_ptr_offset_of_batch.GetCPtrOffset(0); #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp index 1ebd0c3d6..abe4bb1d3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_dl_instance.hpp @@ -54,7 +54,6 @@ using device_grouped_conv2d_fwd_dl_f16_instances = std::tuple< // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instances - // TODO: Change to ScalarPerVector = 1 when inner_product will be supported DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 8, 16, 4, 2, 1, 1, 2, 1, S<4, 2>, S<1, 1>, S<2, 1, 2, 1>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 1, 1>, S<2, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>, DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1>, diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp index bf95727fb..213202032 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp @@ -33,8 +33,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k) { - // K or C are odd is supported only by DL kernel (only applies to fp16) - // DL kernel is only supported for split_k=1 + // Odd K or C values are supported only by DL kernel (only applies to fp16) + // DL kernel currently supports only `split_k=1` if constexpr(std::is_same_v) { if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0)) -- GitLab From f17af2e9ed7e9a9625e415927c8cb6b3efa1b172 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Sun, 17 Sep 2023 20:56:27 -0500 Subject: [PATCH 020/112] Add native conversions fp8<->fp32 (#908) * Add native conversions * Add bf8 conversions --- include/ck/utility/f8_utils.hpp | 5 +- include/ck/utility/type_convert.hpp | 113 ++++++++++++++++++++++++++-- 2 files changed, 110 insertions(+), 8 deletions(-) diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index 5fbebb708..217b339b6 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -5,6 +5,8 @@ #include "ck/utility/data_type.hpp" +// these conversions are disabled if native conversions available +#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__) #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 namespace ck { @@ -242,4 +244,5 @@ __host__ __device__ Y cast_from_f8(X x) } } // namespace ck::utils -#endif +#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8 +#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__) diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 5c5447f94..70619ee0a 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -85,6 +85,19 @@ inline __host__ __device__ constexpr bhalf_t type_convert(int8_ template <> inline __host__ __device__ f8_t type_convert(float x) { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + union + { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // not endian independent + } val; + val.fval = x; + uint32_t ival = 0; + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0 + val.i32val = ival; + return val.i8val[0]; +#else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; @@ -92,20 +105,33 @@ inline __host__ __device__ f8_t type_convert(float x) return utils:: cast_to_f8(x, rng); +#endif } // convert fp8 to fp32 template <> inline __host__ __device__ float type_convert(f8_t x) { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + float fval; + uint32_t i32val = static_cast(x); + fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); + // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + return fval; +#else constexpr bool negative_zero_nan = true; return utils::cast_from_f8(x); +#endif } // convert fp16 to fp8 template <> inline __host__ __device__ f8_t type_convert(half_t x) { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + // convert to float and use native converion + return type_convert(type_convert(x)); +#else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; @@ -113,14 +139,20 @@ inline __host__ __device__ f8_t type_convert(half_t x) return utils:: cast_to_f8( x, rng); +#endif } // convert fp8 to fp16 template <> inline __host__ __device__ half_t type_convert(f8_t x) { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + // use native conversion to float and convert to fp16 + return type_convert(type_convert(x)); +#else constexpr bool negative_zero_nan = true; return utils::cast_from_f8(x); +#endif } #endif @@ -129,6 +161,19 @@ inline __host__ __device__ half_t type_convert(f8_t x) template <> inline __host__ __device__ bf8_t type_convert(float x) { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + union + { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // not endian independent + } val; + val.fval = x; + uint32_t ival = 0; + ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0 + val.i32val = ival; + return val.i8val[0]; +#else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; @@ -136,20 +181,33 @@ inline __host__ __device__ bf8_t type_convert(float x) return utils:: cast_to_f8( x, rng); +#endif } // convert bf8 to fp32 template <> inline __host__ __device__ float type_convert(bf8_t x) { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + float fval; + uint32_t i32val = static_cast(x); + fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); + // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + return fval; +#else constexpr bool negative_zero_nan = true; return utils::cast_from_f8(x); +#endif } // convert fp16 to bf8 template <> inline __host__ __device__ bf8_t type_convert(half_t x) { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + // convert to float and use native converion + return type_convert(type_convert(x)); +#else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::standard; @@ -157,14 +215,20 @@ inline __host__ __device__ bf8_t type_convert(half_t x) return utils:: cast_to_f8( x, rng); +#endif } // convert bf8 to fp16 template <> inline __host__ __device__ half_t type_convert(bf8_t x) { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + // use native conversion to float and convert to fp16 + return type_convert(type_convert(x)); +#else constexpr bool negative_zero_nan = true; return utils::cast_from_f8(x); +#endif } #endif @@ -234,30 +298,47 @@ __host__ __device__ constexpr Y f8_convert_sr(X x); template <> inline __host__ __device__ f8_t f8_convert_sr(float x) { + constexpr int seed = 42; + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + union + { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // not endian independent + } val; + val.fval = x; + uint32_t ival = 0; + ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; + return val.i8val[0]; // little endian +#else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; - constexpr int seed = 42; - // as thread id is not available on host, use 0 for prn generation - uint32_t rng = prand_generator(reinterpret_cast(&x), x); return utils:: cast_to_f8(x, rng); +#endif } // convert fp16 to fp8 with stochastic rounding template <> inline __host__ __device__ f8_t f8_convert_sr(half_t x) { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + // convert to float and use native converion + return f8_convert_sr(type_convert(x)); +#else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; constexpr int seed = 42; - // as thread id is not available on host, use 0 for prn generation uint32_t rng = prand_generator(reinterpret_cast(&x), x); return utils:: cast_to_f8( x, rng); +#endif } #endif @@ -266,21 +347,38 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) template <> inline __host__ __device__ bf8_t f8_convert_sr(float x) { + constexpr int seed = 42; + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + union + { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // not endian independent + } val; + val.fval = x; + uint32_t ival = 0; + ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; + return val.i8val[0]; // little endian +#else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; - constexpr int seed = 42; - // as thread id is not available on host, use 0 for prn generation - uint32_t rng = prand_generator(reinterpret_cast(&x), x); return utils:: cast_to_f8( x, rng); +#endif } // convert fp16 to bf8 with stochastic rounding template <> inline __host__ __device__ bf8_t f8_convert_sr(half_t x) { +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + // convert to float and use native converion + return f8_convert_sr(type_convert(x)); +#else constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; @@ -290,6 +388,7 @@ inline __host__ __device__ bf8_t f8_convert_sr(half_t x) return utils:: cast_to_f8( x, rng); +#endif } #endif -- GitLab From 63cd4592487f9e1ba32ed7c62990faee375c504a Mon Sep 17 00:00:00 2001 From: Bartlomiej Wroblewski Date: Mon, 18 Sep 2023 14:08:23 +0200 Subject: [PATCH 021/112] Fix DL GEMM instances with too large vector size (#901) * Fix vector lengths of DL GEMM instances with padding * Add checks for correctness of vector lenghts in DL GEMM --- .../gpu/device/impl/device_gemm_dl.hpp | 51 +++++++++++++++++++ ...16_f16_f16_km_kn_mn_irregular_instance.cpp | 28 +++++----- ...16_f16_f16_km_nk_mn_irregular_instance.cpp | 28 +++++----- ...16_f16_f16_mk_kn_mn_irregular_instance.cpp | 28 +++++----- ...l_i8_i8_i8_km_kn_mn_irregular_instance.cpp | 44 ++++++++-------- ...l_i8_i8_i8_km_nk_mn_irregular_instance.cpp | 44 ++++++++-------- ...l_i8_i8_i8_mk_kn_mn_irregular_instance.cpp | 44 ++++++++-------- 7 files changed, 159 insertions(+), 108 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp index d9f8324bd..514aa4452 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -273,6 +273,9 @@ struct DeviceGemmDl : public DeviceGemm::value) + { + constexpr auto A_K_vec_length = + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I0) * + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I3); + if(arg.K_raw_ % A_K_vec_length != 0) + { + return false; + } + } + else + { + constexpr auto A_M_vec_lenght = + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I1) * + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I2); + if(arg.M_raw_ % A_M_vec_lenght != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + constexpr auto B_N_vec_lenght = + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I1) * + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I2); + if(arg.N_raw_ % B_N_vec_lenght != 0) + { + return false; + } + } + else + { + constexpr auto B_K_vec_length = + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I0) * + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I3); + if(arg.K_raw_ % B_K_vec_length != 0) + { + return false; + } + } + if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102") diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp index d444e29aa..8e39b47f0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp @@ -35,26 +35,26 @@ using device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances = std::tuple< // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // MPerBlock=128, NPerBlock=128 - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=64 - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=16, NPerBlock=64 - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=16 - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=8, NPerBlock=64 - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=8 - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=8, NPerBlock=8 - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp index f8f8a0bd3..c18d8ae73 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp @@ -35,26 +35,26 @@ using device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances = std::tuple< // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // MPerBlock=128, NPerBlock=128 - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=64 - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=16, NPerBlock=64 - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=16 - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=8, NPerBlock=64 - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=8 - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=8, NPerBlock=8 - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp index 840a4fabe..1959f2213 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp @@ -35,26 +35,26 @@ using device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances = std::tuple< // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // MPerBlock=128, NPerBlock=128 - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=64 - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 2, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 2>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<2, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=16, NPerBlock=64 - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=16 - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 16, 16, 2, 4, 1, 1, S<4, 2>, S<4, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<16, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=8, NPerBlock=64 - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<4, 1>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=8 - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<4, 1>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=8, NPerBlock=8 - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp index d1eb8edf9..f88d9a9a6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp @@ -32,36 +32,36 @@ using device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances = std::tuple< // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // MPerBlock=128, NPerBlock=128 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=128, NPerBlock=64 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=128 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=64 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=32, NPerBlock=32 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=16, NPerBlock=16 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 4, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 4, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 4, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 4, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 4, 4>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=8, NPerBlock=64 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 4, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 4, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=8 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 4, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 4, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=8, NPerBlock=8 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1> + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 4, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<1, 1, 4, 4>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp index 3de5458e0..ba30ac8ed 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp @@ -32,36 +32,36 @@ using device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances = std::tuple< // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // MPerBlock=128, NPerBlock=128 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=128, NPerBlock=64 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=128 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=64 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=32, NPerBlock=32 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=16, NPerBlock=16 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=8, NPerBlock=64 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=8 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=8, NPerBlock=8 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp index 745a4bb31..ad810276f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp @@ -32,36 +32,36 @@ using device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances = std::tuple< // #########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | // #########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // MPerBlock=128, NPerBlock=128 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<8, 2>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<4, 4>, S<4, 2>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 128, 128, 128, 16, 4, 4, 8, 1, S<2, 8>, S<2, 4>, S<8, 1, 2, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 8, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=128, NPerBlock=64 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 128, 64, 16, 4, 4, 2, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=128 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<4, 4>, S<4, 4>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 256, 64, 128, 16, 4, 2, 4, 1, S<2, 8>, S<2, 8>, S<8, 1, 1, 4>, S<2, 1, 64, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=64 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<8, 1>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 64, 8, 4, 4, 4, 1, S<4, 2>, S<8, 1>, S<4, 1, 2, 4>, S<2, 1, 32, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=32, NPerBlock=32 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 2, 4, 1, S<4, 2>, S<2, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<4, 2>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 32, 32, 32, 8, 4, 4, 2, 1, S<2, 2>, S<2, 4>, S<4, 1, 2, 4>, S<2, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<2, 1, 4, 4>, S<4, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=16, NPerBlock=16 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<2, 2>, S<2, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 16, 16, 16, 16, 2, 2, 2, 1, S<4, 1>, S<4, 1>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 4, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=8, NPerBlock=64 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 8, 64, 32, 2, 1, 2, 1, S<2, 2>, S<8, 2>, S<4, 1, 1, 2>, S<8, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=64, NPerBlock=8 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 64, 64, 8, 32, 2, 2, 1, 1, S<8, 2>, S<2, 2>, S<8, 1, 4, 2>, S<4, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<32, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, // MPerBlock=8, NPerBlock=8 - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, - DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<4, 1>, S<2, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 1, 2, 1, S<1, 4>, S<1, 2>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<2, 1>, S<4, 1>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>, + DeviceGemmDl< int8_t, int8_t, int8_t, int32_t, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 8, 8, 8, 4, 2, 2, 1, 1, S<1, 2>, S<1, 4>, S<4, 1, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<4, 1, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<1, 1, 4, 2>, S<4, 1, 2, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1> // clang-format on >; -- GitLab From 5a4416c8a7b6e28257a580d212373bd54a4aae3d Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 19 Sep 2023 09:17:39 -0700 Subject: [PATCH 022/112] fix the ckprofiler package build in a loop (#926) --- CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e3b60347d..32defea75 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -105,6 +105,7 @@ if(NOT PROFILER_ONLY) TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") else() add_definitions(-DPROFILER_ONLY) + set(GPU_TARGETS "" CACHE STRING "" FORCE) if(GPU_TARGETS) message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx9, gfx10, or gfx11") endif() @@ -117,11 +118,12 @@ else() else() message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx9, gfx10, or gfx11") endif() + set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) endif() message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}") -set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " ") +set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) if(GPU_TARGETS) message("Building CK for the following targets: ${GPU_TARGETS}") -- GitLab From 718065ebd2509251291a3c3adeb7d0cef2ecb8dc Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 19 Sep 2023 09:35:45 -0700 Subject: [PATCH 023/112] update to rocm5.7 by default (#925) * update to rocm5.7 by default * fix jenkinsfile syntax --- Dockerfile | 25 ++++++------------------- Jenkinsfile | 10 +++++----- 2 files changed, 11 insertions(+), 24 deletions(-) diff --git a/Dockerfile b/Dockerfile index b0954fdc7..4ba13649e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,25 +16,12 @@ RUN apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg -RUN wget https://repo.radeon.com/amdgpu-install/5.6/ubuntu/focal/amdgpu-install_5.6.50600-1_all.deb --no-check-certificate -RUN apt-get update && \ -DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ - ./amdgpu-install_5.6.50600-1_all.deb - -RUN if [ "$ROCMVERSION" != "5.7" ]; then \ - wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ - sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ - elif [ "$ROCMVERSION" = "5.7" ] && [ "$compiler_version" = "" ] || [ "$compiler_version" = "amd-stg-open" ]; then \ - sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_5.7-20.04-1_all.deb" && \ - apt update && apt-get install -y ./amdgpu-install-internal_5.7-20.04-1_all.deb && \ - amdgpu-repo --amdgpu-build=1609671 --rocm-build=compute-rocm-npi-mi300/1354; \ - elif [ "$ROCMVERSION" = "5.7" ] && [ "$compiler_version" = "rc1" ]; then \ - sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_5.7-20.04-1_all.deb" && \ - apt update && apt-get install -y ./amdgpu-install-internal_5.7-20.04-1_all.deb && \ - sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 5.7 rel-19 > /etc/apt/sources.list.d/rocm-build.list' && \ - amdgpu-repo --amdgpu-build=1637781; \ - fi +RUN wget https://repo.radeon.com/amdgpu-install/5.7/ubuntu/focal/amdgpu-install_5.7.50700-1_all.deb --no-check-certificate +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_5.7.50700-1_all.deb + +RUN wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ + sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ + sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list' RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" RUN amdgpu-install -y --usecase=rocm --no-dkms diff --git a/Jenkinsfile b/Jenkinsfile index c1a28a74e..0fdadf8a8 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -33,7 +33,7 @@ def runShell(String command){ def getDockerImageName(){ def img - if (params.ROCMVERSION != "5.7"){ + if (params.ROCMVERSION != "6.0"){ if (params.COMPILER_VERSION == "") { img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" } @@ -613,8 +613,8 @@ def process_results(Map conf=[:]){ } //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version -CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION=rc1 - 0 21 * * * % ROCMVERSION=5.6;COMPILER_VERSION=;COMPILER_COMMIT= +CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION= + 0 21 * * * % ROCMVERSION=5.7;COMPILER_VERSION=;COMPILER_COMMIT= 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : "" pipeline { @@ -632,8 +632,8 @@ pipeline { description: "Force building docker image (default: false), set to true if docker image needs to be updated.") string( name: 'ROCMVERSION', - defaultValue: '5.6', - description: 'Specify which ROCM version to use: 5.6 (default).') + defaultValue: '5.7', + description: 'Specify which ROCM version to use: 5.7 (default).') string( name: 'COMPILER_VERSION', defaultValue: '', -- GitLab From 58817bf967d3b7b314233f6331f0430cf40651ae Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 19 Sep 2023 18:50:58 -0700 Subject: [PATCH 024/112] fix the building of the amd-stg-open compiler (#927) --- Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 4ba13649e..4f227d036 100644 --- a/Dockerfile +++ b/Dockerfile @@ -101,7 +101,7 @@ RUN sh -c "echo compiler commit = '$compiler_commit'" RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ - cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ make -j 8 ; \ else echo "using the release compiler"; \ fi @@ -109,7 +109,7 @@ RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" = "" ]; RUN if [ "$compiler_version" = "amd-stg-open" ] && [ "$compiler_commit" != "" ]; then \ git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ - cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld;compiler-rt" ../llvm && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ make -j 8 ; \ else echo "using the release compiler"; \ fi -- GitLab From bba085d2b58da0b222e784b8a335c3a24acbacc9 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 20 Sep 2023 22:15:56 -0700 Subject: [PATCH 025/112] Refactoring cmake files to build data types separately. (#932) * refactor cmake files for the tests * refactor cmake files for examples * fix cmake for gemm example * fix the cmake file for all examples * add splitting by data types in gemm_splitk instance header * rename test to reflect only dl instances are used * clean up CI workspace, update cmake for instances * change the jenkinsfile syntax * build all instances except DL on gfx11 * move workspace cleanup after stages * clean up workspace after every stage * isolate data types in grouped_conv_fwd header * isolate dl instances for grouped_conv2d_fwd * fix syntax * fix cmake and batchnorm instances * fix typo * fix reduction instances * fix grouped_conv headers * fix syntax * replace parsing logic for instances, replace bfp16 with bf16 * fix the client examples build * clean up DTYPES from instances cmake files * update the parsing logic in cmake files * make an exception for reduction kernels * update few remaining cmake files to handle DTYPES * fix syntax * fix cmake conflicts * replace f8 with fp8 test name * resolve conflicts for dpp instances --- CMakeLists.txt | 15 +- example/01_gemm/CMakeLists.txt | 82 +++---- example/02_gemm_bilinear/CMakeLists.txt | 2 - example/03_gemm_bias_relu/CMakeLists.txt | 2 - .../04_gemm_add_add_fastgelu/CMakeLists.txt | 16 +- example/09_convnd_fwd/CMakeLists.txt | 26 +-- .../CMakeLists.txt | 32 +-- example/13_pool2d_fwd/CMakeLists.txt | 8 +- example/14_gemm_quantization/CMakeLists.txt | 7 +- example/15_grouped_gemm/CMakeLists.txt | 58 ++--- ...dl_bfp16.cpp => grouped_gemm_xdl_bf16.cpp} | 0 .../CMakeLists.txt | 43 ++-- example/17_convnd_bwd_data/CMakeLists.txt | 13 +- example/18_batched_gemm_reduce/CMakeLists.txt | 2 - .../20_grouped_conv_bwd_weight/CMakeLists.txt | 20 +- example/21_gemm_layernorm/CMakeLists.txt | 3 +- example/22_cgemm/CMakeLists.txt | 16 +- example/24_batched_gemm/CMakeLists.txt | 22 +- ...dl_bfp16.cpp => batched_gemm_xdl_bf16.cpp} | 0 example/25_gemm_bias_e_permute/CMakeLists.txt | 6 +- example/26_contraction/CMakeLists.txt | 12 +- example/27_layernorm/CMakeLists.txt | 6 +- .../CMakeLists.txt | 4 +- .../CMakeLists.txt | 8 +- .../CMakeLists.txt | 32 +-- example/31_batched_gemm_gemm/CMakeLists.txt | 18 +- .../CMakeLists.txt | 35 +-- example/35_splitK_gemm/CMakeLists.txt | 23 +- ...xdl_bfp16.cpp => splitK_gemm_xdl_bf16.cpp} | 0 .../CMakeLists.txt | 4 +- .../CMakeLists.txt | 11 +- example/39_permute/CMakeLists.txt | 15 +- .../40_conv2d_fwd_quantization/CMakeLists.txt | 4 - .../41_grouped_conv_conv_fwd/CMakeLists.txt | 16 +- example/42_groupnorm/CMakeLists.txt | 8 +- .../CMakeLists.txt | 8 +- example/44_elementwise_permute/CMakeLists.txt | 6 +- example/46_gemm_add_multiply/CMakeLists.txt | 8 +- example/48_pool3d_fwd/CMakeLists.txt | 4 +- example/49_maxpool2d_bwd/CMakeLists.txt | 12 +- example/50_put_element/CMakeLists.txt | 4 +- example/CMakeLists.txt | 114 +++++++++- .../gpu/batchnorm_backward.hpp | 50 +++-- .../gpu/batchnorm_forward.hpp | 44 ++-- .../gpu/batchnorm_infer.hpp | 44 ++-- .../gpu/gemm_splitk.hpp | 17 +- .../gpu/grouped_convolution_backward_data.hpp | 60 ++++-- .../grouped_convolution_backward_weight.hpp | 133 +++++++++--- .../gpu/grouped_convolution_forward.hpp | 175 ++++++++++----- .../gpu/reduce/device_reduce_instance.hpp | 59 +++--- .../gpu/CMakeLists.txt | 102 +++++++-- .../gpu/avg_pool3d_bwd/CMakeLists.txt | 12 +- .../gpu/batched_gemm/CMakeLists.txt | 22 +- .../CMakeLists.txt | 2 - .../batched_gemm_bias_permute/CMakeLists.txt | 3 +- .../gpu/batched_gemm_gemm/CMakeLists.txt | 2 - .../gpu/batched_gemm_multi_d/CMakeLists.txt | 44 ++-- .../gpu/batched_gemm_reduce/CMakeLists.txt | 2 - .../batched_gemm_softmax_gemm/CMakeLists.txt | 2 - .../CMakeLists.txt | 14 +- .../gpu/contraction_bilinear/CMakeLists.txt | 27 ++- .../gpu/contraction_scale/CMakeLists.txt | 26 ++- .../gpu/conv2d_bwd_data/CMakeLists.txt | 29 +-- .../gpu/conv2d_fwd/CMakeLists.txt | 20 +- .../elementwise_normalization/CMakeLists.txt | 2 - .../gpu/gemm/CMakeLists.txt | 200 ++++++++---------- .../gpu/gemm_add_add_fastgelu/CMakeLists.txt | 2 - .../gpu/gemm_add_fastgelu/CMakeLists.txt | 2 - .../CMakeLists.txt | 2 - .../gpu/gemm_bilinear/CMakeLists.txt | 2 - .../gpu/gemm_fastgelu/CMakeLists.txt | 2 - .../gpu/gemm_multiply_add/CMakeLists.txt | 15 +- .../gpu/gemm_splitk/CMakeLists.txt | 40 ++-- .../gpu/gemm_streamk/CMakeLists.txt | 2 - .../gpu/grouped_gemm/CMakeLists.txt | 2 - .../gpu/grouped_gemm_fastgelu/CMakeLists.txt | 2 - .../gpu/max_pool_bwd/CMakeLists.txt | 12 +- .../gpu/normalization/CMakeLists.txt | 13 +- .../gpu/pool3d_fwd/CMakeLists.txt | 18 +- .../gpu/quantization/CMakeLists.txt | 24 +-- .../gpu/softmax/CMakeLists.txt | 11 +- test/CMakeLists.txt | 121 +++++++++-- test/batched_gemm/CMakeLists.txt | 28 ++- test/batched_gemm_gemm/CMakeLists.txt | 8 +- test/batched_gemm_multi_d/CMakeLists.txt | 4 +- ...d.cpp => test_batched_gemm_multi_d_dl.cpp} | 0 test/batched_gemm_reduce/CMakeLists.txt | 7 +- test/batched_gemm_softmax_gemm/CMakeLists.txt | 14 +- .../CMakeLists.txt | 41 ++-- test/data_type/CMakeLists.txt | 15 +- test/data_type/{f8.cpp => fp8.cpp} | 0 test/elementwise_normalization/CMakeLists.txt | 6 +- test/gemm/CMakeLists.txt | 30 ++- test/gemm_layernorm/CMakeLists.txt | 10 +- test/gemm_reduce/CMakeLists.txt | 7 +- test/grouped_gemm/CMakeLists.txt | 2 - test/normalization/CMakeLists.txt | 24 ++- test/reduce/CMakeLists.txt | 6 +- 98 files changed, 1281 insertions(+), 1007 deletions(-) rename example/15_grouped_gemm/{grouped_gemm_xdl_bfp16.cpp => grouped_gemm_xdl_bf16.cpp} (100%) rename example/24_batched_gemm/{batched_gemm_xdl_bfp16.cpp => batched_gemm_xdl_bf16.cpp} (100%) rename example/35_splitK_gemm/{splitK_gemm_xdl_bfp16.cpp => splitK_gemm_xdl_bf16.cpp} (100%) rename test/batched_gemm_multi_d/{test_batched_gemm_multi_d.cpp => test_batched_gemm_multi_d_dl.cpp} (100%) rename test/data_type/{f8.cpp => fp8.cpp} (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 32defea75..95ef2df7e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,5 @@ cmake_minimum_required(VERSION 3.14) +cmake_policy(SET CMP0140 NEW) # This has to be initialized before the project() command appears # Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE @@ -383,31 +384,31 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu set(cmake_instance) file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance) set(add_inst 0) - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8") + if(("${cmake_instance}" MATCHES "fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8") #message("fp8 instance found!") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf8\" " AND DTYPES MATCHES "bf8") + if(("${cmake_instance}" MATCHES "bf8" OR "${cmake_instance}" MATCHES "_b8") AND DTYPES MATCHES "bf8") #message("bf8 instance found!") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16") + if(("${cmake_instance}" MATCHES "fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16") #message("fp16 instance found!") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32") + if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32") #message("fp32 instance found!") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64") + if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64") #message("fp64 instance found!") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16") + if(("${cmake_instance}" MATCHES "bf16" OR "${cmake_instance}" MATCHES "_b16") AND DTYPES MATCHES "bf16") #message("bf16 instance found!") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8") + if(("${cmake_instance}" MATCHES "int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8") #message("int8 instance found!") set(add_inst 1) endif() diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 5574d0900..2351d3f10 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -1,51 +1,54 @@ -if(DL_KERNELS) - add_custom_target(example_gemm_dl) +add_custom_target(example_gemm_dl) - add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) - add_dependencies(example_gemm_dl example_gemm_dl_fp32) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) +add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp) +if(result EQUAL 0) + add_dependencies(example_gemm_dl example_gemm_dl_fp32) +endif() +add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_dl example_gemm_dl_fp16) - add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) +endif() +add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp) +add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_dl example_gemm_dl_int8) - endif() - - if(USE_BITINT_EXTENSION_INT4) +endif() +if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp) add_dependencies(example_gemm_dl example_gemm_dl_int4) - endif(USE_BITINT_EXTENSION_INT4) -endif() +endif(USE_BITINT_EXTENSION_INT4) add_custom_target(example_gemm_xdl) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) - add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_fp16) - add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) - add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) - add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) - - if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") +add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) +if(result EQUAL 0) + add_dependencies(example_gemm_xdl example_gemm_xdl_fp16) +endif() +add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) +if(result EQUAL 0) + add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) +endif() +add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) +if(result EQUAL 0) + add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) +endif() +if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") add_custom_target(example_gemm_wmma) add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) - add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) - endif() - + if(result EQUAL 0) + add_dependencies(example_gemm_wmma example_gemm_wmma_fp16) + endif() endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) +add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_xdl example_gemm_xdl_bf16) add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp) add_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn) endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) +add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_xdl example_gemm_xdl_int8) endif() @@ -54,22 +57,23 @@ if(USE_BITINT_EXTENSION_INT4) add_dependencies(example_gemm_xdl example_gemm_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed - add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) +# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed +add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_xdl example_gemm_xdl_fp64) endif() add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) - if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") - add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp) + +if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") + add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp) + if(result EQUAL 0) add_dependencies(example_gemm_xdl example_gemm_xdl_f8) endif() endif() -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) +add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8) endif() diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index 52e633052..d82c42d5a 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1,4 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102) list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) @@ -19,4 +18,3 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() -endif() diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt index a247a052c..2f5cba924 100644 --- a/example/03_gemm_bias_relu/CMakeLists.txt +++ b/example/03_gemm_bias_relu/CMakeLists.txt @@ -1,4 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() -endif() diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index 15ec62c89..9fe833dda 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -3,24 +3,24 @@ set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_gemm_add_add_fastgelu_xdl) - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) + add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) + if(result EQUAL 0) add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) + add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) + if(result EQUAL 0) add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) + add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) + if(result EQUAL 0) add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) + add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) + if(result EQUAL 0) add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) endif() set(target 1) diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 1af1e6c85..f9903bfe0 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -2,34 +2,16 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) - endif() - # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed - if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) - endif() set(target 1) endif() endforeach() -if(DL_KERNELS) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) - endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) - endif() -endif() +add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) +add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) +add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt index e7d941ae6..150d146e3 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt @@ -2,27 +2,27 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_fwd_reduce_xdl) - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) + add_custom_target(example_convnd_fwd_reduce_xdl) + add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) + if(result EQUAL 0) add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) + endif() + add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) + if(result EQUAL 0) add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) + endif() + add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) + if(result EQUAL 0) add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) - endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) + endif() + add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) + if(result EQUAL 0) add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) - endif() - if(USE_BITINT_EXTENSION_INT4) + endif() + if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) + endif(USE_BITINT_EXTENSION_INT4) + set(target 1) endif() endforeach() \ No newline at end of file diff --git a/example/13_pool2d_fwd/CMakeLists.txt b/example/13_pool2d_fwd/CMakeLists.txt index d0f356757..e2a923cde 100644 --- a/example/13_pool2d_fwd/CMakeLists.txt +++ b/example/13_pool2d_fwd/CMakeLists.txt @@ -1,6 +1,2 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) -endif() +add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp) +add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp) diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index 3b3ad80dd..9793e8b8a 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,9 +1,5 @@ -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) # dlops -if(DL_KERNELS) - add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) -endif() - +add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) # xdlops list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) @@ -14,4 +10,3 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() -endif() \ No newline at end of file diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 45e02126b..f2c76b76f 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -1,36 +1,44 @@ add_custom_target(example_grouped_gemm_xdl) - -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) +add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp) +if(result EQUAL 0) add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32) endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) - add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) - add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) - add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp) - add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp) - add_dependencies(example_grouped_gemm_xdl - example_grouped_gemm_xdl_fp16 - example_grouped_gemm_multiple_d_dl_fp16 - example_grouped_gemm_xdl_splitk_fp16 - example_grouped_gemm_xdl_fixed_nk_fp16 - example_grouped_gemm_xdl_fixed_nk_bias_fp16) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bfp16) -endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) +add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp) +if(result EQUAL 0) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16) +endif() +add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp) +if(result EQUAL 0) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16) +endif() +add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp) +if(result EQUAL 0) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16) +endif() +add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp) +if(result EQUAL 0) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16) +endif() +add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp) +if(result EQUAL 0) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16) +endif() +add_example_executable(example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp) +if(result EQUAL 0) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16) +endif() +add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) +if(result EQUAL 0) add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8) endif() -if(DTYPES MATCHES "f8" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp) +add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp) +if(result EQUAL 0) add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8) endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) - add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) + if(result EQUAL 0) + add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) + endif() endif() diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp similarity index 100% rename from example/15_grouped_gemm/grouped_gemm_xdl_bfp16.cpp rename to example/15_grouped_gemm/grouped_gemm_xdl_bf16.cpp diff --git a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt index 00786d34a..67cf5666d 100644 --- a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt +++ b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt @@ -6,30 +6,43 @@ foreach(gpu IN LISTS GPU_TARGETS) add_custom_target(example_gemm_reduce_xdl_max) add_custom_target(example_gemm_reduce_xdl_mean_meansquare) add_custom_target(example_gemm_add_add_mean_meansquare_xdl) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) - add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) - add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) + add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) + if(result EQUAL 0) add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) + endif() + add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) + if(result EQUAL 0) add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) + endif() + add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) + if(result EQUAL 0) add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) - add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) + + add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) + if(result EQUAL 0) add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) + endif() + add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) + if(result EQUAL 0) add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) - add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) + + add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) + if(result EQUAL 0) add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) + endif() + add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) + if(result EQUAL 0) add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) - add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) + + add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) + if(result EQUAL 0) add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) + endif() + add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) + if(result EQUAL 0) add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) endif() @@ -40,7 +53,9 @@ foreach(gpu IN LISTS GPU_TARGETS) if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) - add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) + if(result EQUAL 0) + add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) + endif() endif() set(target 1) endif() diff --git a/example/17_convnd_bwd_data/CMakeLists.txt b/example/17_convnd_bwd_data/CMakeLists.txt index e187bd433..7c6d10d8a 100644 --- a/example/17_convnd_bwd_data/CMakeLists.txt +++ b/example/17_convnd_bwd_data/CMakeLists.txt @@ -1,15 +1,16 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) - target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) + if(result EQUAL 0) + target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) + endif() set(target 1) endif() endforeach() - if(DL_KERNELS) - add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) - target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) - endif() + +add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) +if(result EQUAL 0) + target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility) endif() diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt index a1bb398af..94ed129dc 100644 --- a/example/18_batched_gemm_reduce/CMakeLists.txt +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -1,4 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() -endif() diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index d649567ed..ab8ed5c2b 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -3,22 +3,20 @@ set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_grouped_conv_bwd_weight) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) + add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) + if(result EQUAL 0) add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) + add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) + if(result EQUAL 0) add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) endif() set(target 1) endif() endforeach() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - if(DL_KERNELS) - add_custom_target(example_grouped_conv_bwd_weight_dl) - add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) - add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) - endif() -endif() \ No newline at end of file +add_custom_target(example_grouped_conv_bwd_weight_dl) +add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) +if(result EQUAL 0) + add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) +endif() diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt index 6a6735efd..e231bc619 100644 --- a/example/21_gemm_layernorm/CMakeLists.txt +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -1,4 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -10,4 +9,4 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() -endif() + diff --git a/example/22_cgemm/CMakeLists.txt b/example/22_cgemm/CMakeLists.txt index 854f07fda..f401f7187 100644 --- a/example/22_cgemm/CMakeLists.txt +++ b/example/22_cgemm/CMakeLists.txt @@ -1,19 +1,19 @@ add_custom_target(example_cgemm_xdl) -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) +add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp) +if(result EQUAL 0) add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16) endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) +add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp) +if(result EQUAL 0) add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16) endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp) -add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32) +if(result EQUAL 0) + add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32) endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) +add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp) +if(result EQUAL 0) add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8) endif() if(USE_BITINT_EXTENSION_INT4) diff --git a/example/24_batched_gemm/CMakeLists.txt b/example/24_batched_gemm/CMakeLists.txt index 48a3b58ff..c80540de7 100644 --- a/example/24_batched_gemm/CMakeLists.txt +++ b/example/24_batched_gemm/CMakeLists.txt @@ -1,21 +1,23 @@ add_custom_target(example_batched_gemm_xdl) -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) +add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp) +if(result EQUAL 0) add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32) endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp) +if(result EQUAL 0) add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16) endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bfp16) +add_example_executable(example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp) +if(result EQUAL 0) + add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16) endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) +add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp) +if(result EQUAL 0) add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8) endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp) - add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) + if(result EQUAL 0) + add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4) + endif() endif() diff --git a/example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp b/example/24_batched_gemm/batched_gemm_xdl_bf16.cpp similarity index 100% rename from example/24_batched_gemm/batched_gemm_xdl_bfp16.cpp rename to example/24_batched_gemm/batched_gemm_xdl_bf16.cpp diff --git a/example/25_gemm_bias_e_permute/CMakeLists.txt b/example/25_gemm_bias_e_permute/CMakeLists.txt index eb274b233..cbc3c007b 100644 --- a/example/25_gemm_bias_e_permute/CMakeLists.txt +++ b/example/25_gemm_bias_e_permute/CMakeLists.txt @@ -1,4 +1,2 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp) - add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp) -endif() +add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp) +add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp) diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt index 6cab88b13..167a9f147 100644 --- a/example/26_contraction/CMakeLists.txt +++ b/example/26_contraction/CMakeLists.txt @@ -1,8 +1,4 @@ -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) - add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) -endif() -if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) - add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) -endif() +add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) +add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) +add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) +add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) diff --git a/example/27_layernorm/CMakeLists.txt b/example/27_layernorm/CMakeLists.txt index 9cb2cd076..94c23ce77 100644 --- a/example/27_layernorm/CMakeLists.txt +++ b/example/27_layernorm/CMakeLists.txt @@ -1,4 +1,2 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp) - add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp) -endif() +add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp) +add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp) diff --git a/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt b/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt index 2fda1f62a..44ab16894 100644 --- a/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt +++ b/example/28_grouped_gemm_bias_e_permute/CMakeLists.txt @@ -1,3 +1 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp) -endif() +add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp) diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index 09c3e6c60..32a87dd20 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1,7 +1,5 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) - if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") - add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) - endif() +if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") + add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) endif() diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index e37413c09..4ab5bd75f 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -5,27 +5,31 @@ set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list1 AND target EQUAL 0) add_custom_target(example_grouped_conv_fwd_multiple_d) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) + if(result EQUAL 0) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) - add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) + endif() + add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) + if(result EQUAL 0) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) + if(result EQUAL 0) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) + if(result EQUAL 0) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) + if(result EQUAL 0) add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) - add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) + if(result EQUAL 0) + add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) + endif() endif() # USE_BITINT_EXTENSION_INT4 set(target 1) @@ -35,12 +39,8 @@ endforeach() set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) - endif() + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) set(target 1) endif() endforeach() diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt b/example/31_batched_gemm_gemm/CMakeLists.txt index 2074520f8..93f16c945 100644 --- a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -1,17 +1,11 @@ list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list2 gfx908 gfx90a) + set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) - endif() + add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) + add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) + add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) endif(USE_BITINT_EXTENSION_INT4) @@ -20,7 +14,5 @@ foreach(gpu IN LISTS GPU_TARGETS) endforeach() if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) - endif() + add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) endif() diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index 0463bf6bd..25eb44ae5 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -1,24 +1,31 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) - add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) - add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) - add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) - add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) - add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) -endif() - add_custom_target(example_gemm_scale_softmax_gemm) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + +add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16) +endif() +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16) +endif() +add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16) +endif() +add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) +endif() +add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) +add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16) +endif() +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) +if(result EQUAL 0) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) endif() + diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index 251a9b93c..8970a5764 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -3,25 +3,28 @@ set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_splitK_gemm_xdl) - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) + + add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) + if(result EQUAL 0) add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) + add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) + if(result EQUAL 0) add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bfp16) + add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) + if(result EQUAL 0) + add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) + add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) + if(result EQUAL 0) add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) endif() if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) - add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) + if(result EQUAL 0) + add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) + endif() endif() set(target 1) endif() diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp similarity index 100% rename from example/35_splitK_gemm/splitK_gemm_xdl_bfp16.cpp rename to example/35_splitK_gemm/splitK_gemm_xdl_bf16.cpp diff --git a/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt b/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt index 36bb5720d..a9be3a710 100644 --- a/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt +++ b/example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt @@ -1,3 +1 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp) -endif() +add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp) diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index 3821f8aac..1675e3663 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -1,15 +1,16 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) add_custom_target(example_grouped_conv_bwd_data) add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp) + if(result EQUAL 0) + add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16) + endif() add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp) - - add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16) - add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16) + if(result EQUAL 0) + add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16) + endif() set(target 1) endif() endforeach() -endif() diff --git a/example/39_permute/CMakeLists.txt b/example/39_permute/CMakeLists.txt index 5b43de972..bcf47b492 100644 --- a/example/39_permute/CMakeLists.txt +++ b/example/39_permute/CMakeLists.txt @@ -1,11 +1,14 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_custom_target(example_permute) - - add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) - add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp) - add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) +add_custom_target(example_permute) +add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp) +if(result EQUAL 0) add_dependencies(example_permute example_permute_1xHxW_fp16) +endif() +add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp) +if(result EQUAL 0) add_dependencies(example_permute example_permute_NxHxW_fp16) +endif() +add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp) +if(result EQUAL 0) add_dependencies(example_permute example_permute_HxWx4_fp16) endif() diff --git a/example/40_conv2d_fwd_quantization/CMakeLists.txt b/example/40_conv2d_fwd_quantization/CMakeLists.txt index 55464957a..2d804cafe 100644 --- a/example/40_conv2d_fwd_quantization/CMakeLists.txt +++ b/example/40_conv2d_fwd_quantization/CMakeLists.txt @@ -1,4 +1,3 @@ -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -11,7 +10,6 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() - if(DL_KERNELS) # Conv perlayer quantization add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) # Conv perchannel quantization @@ -24,5 +22,3 @@ endforeach() add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) # Conv + bias + tanh perchannel quantization add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) - endif() -endif() \ No newline at end of file diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt index 085d90a9e..ae251e88d 100644 --- a/example/41_grouped_conv_conv_fwd/CMakeLists.txt +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -3,15 +3,9 @@ list(APPEND gpu_list2 gfx908 gfx90a) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) - endif() + add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) + add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) + add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) endif(USE_BITINT_EXTENSION_INT4) @@ -20,7 +14,5 @@ foreach(gpu IN LISTS GPU_TARGETS) endforeach() if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) - endif() + add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) endif() diff --git a/example/42_groupnorm/CMakeLists.txt b/example/42_groupnorm/CMakeLists.txt index bc2246a2b..e8c306ac5 100644 --- a/example/42_groupnorm/CMakeLists.txt +++ b/example/42_groupnorm/CMakeLists.txt @@ -1,5 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp) - add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp) - add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp) -endif() +add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp) +add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp) +add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp) diff --git a/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt b/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt index 7e070f535..c29f18f16 100644 --- a/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt +++ b/example/43_splitk_gemm_bias_e_permute/CMakeLists.txt @@ -1,6 +1,2 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp) -endif() +add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp) +add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp) diff --git a/example/44_elementwise_permute/CMakeLists.txt b/example/44_elementwise_permute/CMakeLists.txt index 877a82031..0e0091a98 100644 --- a/example/44_elementwise_permute/CMakeLists.txt +++ b/example/44_elementwise_permute/CMakeLists.txt @@ -1,4 +1,2 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) - add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp) -endif() +add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp) +add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp) diff --git a/example/46_gemm_add_multiply/CMakeLists.txt b/example/46_gemm_add_multiply/CMakeLists.txt index cf7d81f89..bfe057e8d 100644 --- a/example/46_gemm_add_multiply/CMakeLists.txt +++ b/example/46_gemm_add_multiply/CMakeLists.txt @@ -1,6 +1,2 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - if(DL_KERNELS) - add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp) - endif() - add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp) -endif() +add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp) +add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp) diff --git a/example/48_pool3d_fwd/CMakeLists.txt b/example/48_pool3d_fwd/CMakeLists.txt index f821f2c97..492cb4d37 100644 --- a/example/48_pool3d_fwd/CMakeLists.txt +++ b/example/48_pool3d_fwd/CMakeLists.txt @@ -1,3 +1 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp) -endif() +add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp) diff --git a/example/49_maxpool2d_bwd/CMakeLists.txt b/example/49_maxpool2d_bwd/CMakeLists.txt index fe98b7c99..b29cf9ccb 100644 --- a/example/49_maxpool2d_bwd/CMakeLists.txt +++ b/example/49_maxpool2d_bwd/CMakeLists.txt @@ -1,9 +1,3 @@ -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp) -endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp) -endif() +add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp) +add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp) +add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp) diff --git a/example/50_put_element/CMakeLists.txt b/example/50_put_element/CMakeLists.txt index eca410008..1b0020ebc 100644 --- a/example/50_put_element/CMakeLists.txt +++ b/example/50_put_element/CMakeLists.txt @@ -1,3 +1 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_example_executable(example_put_element_fp16 put_element_fp16.cpp) -endif() +add_example_executable(example_put_element_fp16 put_element_fp16.cpp) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 1fdd2f6d1..d030c5345 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -7,20 +7,114 @@ add_custom_target(examples) function(add_example_executable EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") - add_executable(${EXAMPLE_NAME} ${FILE_NAME}) - target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) - add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) - add_dependencies(examples ${EXAMPLE_NAME}) - add_dependencies(check ${EXAMPLE_NAME}) - rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 1) + if(DEFINED DTYPES) + foreach(source IN LISTS FILE_NAME) + set(test 0) + foreach(type IN LISTS DTYPES) + if(type MATCHES "fp16") + set(type1 "_f16") + elseif(type MATCHES "fp32") + set(type1 "_f32") + elseif(type MATCHES "fp8") + set(type1 "_f8") + elseif(type MATCHES "bf16") + set(type1 "_b16") + elseif(type MATCHES "fp64") + set(type1 "_f64") + elseif(type MATCHES "int8") + set(type1 "_i8") + endif() + if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") + #if filename matches any selected type, exit type loop and do no exclude the file from the list + set(test 0) + break() + elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR + source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND + NOT(source MATCHES type OR source MATCHES type1)) + #if filename contains a type which doesn't match any selected type, mark it for removal + set(test 1) + endif() + endforeach() + if(test EQUAL 1) + message("removing example source file ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + endif() + foreach(source IN LISTS FILE_NAME) + if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + message("removing dl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #only continue if there are some source files left on the list + if(FILE_NAME) + add_executable(${EXAMPLE_NAME} ${FILE_NAME}) + target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) + add_dependencies(examples ${EXAMPLE_NAME}) + add_dependencies(check ${EXAMPLE_NAME}) + rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 0) + endif() + #message("add_example returns ${result}") + return(PROPAGATE result) endfunction(add_example_executable EXAMPLE_NAME) function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") - add_executable(${EXAMPLE_NAME} ${FILE_NAME}) - target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) - add_dependencies(examples ${EXAMPLE_NAME}) - rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 1) + if(DEFINED DTYPES) + foreach(source IN LISTS FILE_NAME) + set(test 0) + foreach(type IN LISTS DTYPES) + if(type MATCHES "fp16") + set(type1 "_f16") + elseif(type MATCHES "fp32") + set(type1 "_f32") + elseif(type MATCHES "fp8") + set(type1 "_f8") + elseif(type MATCHES "bf16") + set(type1 "_b16") + elseif(type MATCHES "fp64") + set(type1 "_f64") + elseif(type MATCHES "int8") + set(type1 "_i8") + endif() + if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") + #if filename matches any selected type, exit type loop and do no exclude the file from the list + set(test 0) + break() + elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR + source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND + NOT(source MATCHES type OR source MATCHES type1)) + #if filename contains a type which doesn't match any selected type, mark it for removal + set(test 1) + endif() + endforeach() + if(test EQUAL 1) + message("removing example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + endif() + foreach(source IN LISTS FILE_NAME) + if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + message("removing dl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #only continue if there are some source files left on the list + if(FILE_NAME) + add_executable(${EXAMPLE_NAME} ${FILE_NAME}) + target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + add_dependencies(examples ${EXAMPLE_NAME}) + rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) + set(result 0) + endif() + #message("add_example returns ${result}") + return(PROPAGATE result) endfunction(add_example_executable_no_testing EXAMPLE_NAME) # add all example subdir diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp index 0e1f6f04e..efb76df25 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp @@ -16,26 +16,26 @@ namespace tensor_operation { namespace device { namespace instance { -// FP16 +#ifdef CK_ENABLE_FP16 void add_device_batchnorm_backward_rank_4_3_f16_instances( std::vector>>&); - -// FP32 +#endif +#ifdef CK_ENABLE_FP32 void add_device_batchnorm_backward_rank_4_3_f32_instances( std::vector>>&); - -// BF16 +#endif +#ifdef CK_ENABLE_BF16 void add_device_batchnorm_backward_rank_4_3_bf16_instances( std::vector>>&); - -// FP64 +#endif +#ifdef CK_ENABLE_FP64 void add_device_batchnorm_backward_rank_4_3_f64_instances( std::vector>>&); - +#endif template > op_ptrs; - +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v && is_same_v && @@ -83,37 +83,43 @@ struct DeviceOperationInstanceFactory< add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs); } } - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) { add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs); } } - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) { add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs); } } - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_FP64 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) { add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs); } } - +#endif return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp index 8fd1c7665..08e35d722 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp @@ -16,26 +16,26 @@ namespace tensor_operation { namespace device { namespace instance { -// FP16 +#ifdef CK_ENABLE_FP16 void add_device_batchnorm_forward_rank_4_3_f16_instances( std::vector< std::unique_ptr>>&); - -// FP32 +#endif +#ifdef CK_ENABLE_FP32 void add_device_batchnorm_forward_rank_4_3_f32_instances( std::vector< std::unique_ptr>>&); - -// BF16 +#endif +#ifdef CK_ENABLE_BF16 void add_device_batchnorm_forward_rank_4_3_bf16_instances( std::vector< std::unique_ptr>>&); - -// FP64 +#endif +#ifdef CK_ENABLE_FP64 void add_device_batchnorm_forward_rank_4_3_f64_instances( std::vector< std::unique_ptr>>&); - +#endif template > op_ptrs; - +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) @@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory< add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs); } } - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v && is_same_v) +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v) { if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) { add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs); } } - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v && is_same_v) +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v) { if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) { add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs); } } - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v && is_same_v) +#endif +#ifdef CK_ENABLE_FP64 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && is_same_v) { if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v) { add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs); } } - +#endif return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_infer.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_infer.hpp index f6f4df7e2..97c216b4a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_infer.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/batchnorm_infer.hpp @@ -16,38 +16,38 @@ namespace tensor_operation { namespace device { namespace instance { -// FP16 +#ifdef CK_ENABLE_FP16 void add_device_batchnorm_infer_rank_4_f16_instances( std::vector, ck::Tuple, ck::tensor_operation::element_wise::NormalizeInInfer, 4>>>&); - -// FP32 +#endif +#ifdef CK_ENABLE_FP32 void add_device_batchnorm_infer_rank_4_f32_instances( std::vector, ck::Tuple, ck::tensor_operation::element_wise::NormalizeInInfer, 4>>>&); - -// BF16 +#endif +#ifdef CK_ENABLE_BF16 void add_device_batchnorm_infer_rank_4_bf16_instances( std::vector, ck::Tuple, ck::tensor_operation::element_wise::NormalizeInInfer, 4>>>&); - -// FP64 +#endif +#ifdef CK_ENABLE_FP64 void add_device_batchnorm_infer_rank_4_f64_instances( std::vector, ck::Tuple, ck::tensor_operation::element_wise::NormalizeInInfer, 4>>>&); - +#endif template > op_ptrs; - +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) @@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { if constexpr(Rank == 4) { add_device_batchnorm_infer_rank_4_f32_instances(op_ptrs); } } - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { if constexpr(Rank == 4) { add_device_batchnorm_infer_rank_4_bf16_instances(op_ptrs); } } - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_FP64 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { if constexpr(Rank == 4) { add_device_batchnorm_infer_rank_4_f64_instances(op_ptrs); } } - +#endif return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp index d239cd770..4cba0875d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp @@ -16,7 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { - +#ifdef CK_ENABLE_FP16 void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances( std::vector>>& @@ -36,7 +36,8 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances( std::vector>>& @@ -56,8 +57,8 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances( std::vector>>& instances); - -#if defined CK_ENABLE_FP8 +#endif +#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances( std::vector>>& @@ -129,7 +130,7 @@ struct DeviceOperationInstanceFactory< static auto GetInstances() { std::vector> op_ptrs; - +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { @@ -154,6 +155,8 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs); } } +#endif +#ifdef CK_ENABLE_FP16 else if constexpr(is_same_v && is_same_v && is_same_v) { @@ -178,7 +181,8 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs); } } -#if defined CK_ENABLE_FP8 +#endif +#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) else if constexpr(is_same_v && is_same_v && is_same_v) { @@ -228,7 +232,6 @@ struct DeviceOperationInstanceFactory< } } #endif - return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 58c906453..2b49e1718 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -16,6 +16,7 @@ namespace device { namespace instance { // conv2d backward data +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( std::vector>>& instances); - +#endif // conv3d backward data +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( std::vector>>& instances); - +#endif template && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs); } +#endif +#ifdef CK_ENABLE_FP32 else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs); } +#endif +#ifdef CK_ENABLE_BF16 else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( op_ptrs); } +#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); } +#endif +#ifdef CK_ENABLE_FP32 else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs); } +#endif +#ifdef CK_ENABLE_BF16 else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( op_ptrs); } +#endif } } else if constexpr(NumDimSpatial == 3) @@ -274,46 +298,58 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( op_ptrs); } +#endif +#ifdef CK_ENABLE_FP32 else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( op_ptrs); } +#endif +#ifdef CK_ENABLE_BF16 else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( op_ptrs); } +#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( op_ptrs); } +#endif +#ifdef CK_ENABLE_FP32 else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( op_ptrs); } +#endif +#ifdef CK_ENABLE_BF16 else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( op_ptrs); } +#endif } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index c7a694b4c..c0a5b4bfe 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -19,6 +19,7 @@ namespace instance { // xdl // conv1d backward weight +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( std::vector>>& instances); - +#endif // conv2d backward weight +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( std::vector>>& instances); - +#endif // conv3d backward weight +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); +#endif #ifdef DL_KERNELS // dl // conv1d backward weight +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances( std::vector>>& instances); - +#endif // conv2d backward weight +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( std::vector>>& instances); - +#endif // conv3d backward weight +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); #endif +#endif template && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { @@ -470,6 +506,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { @@ -478,6 +516,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) @@ -489,21 +529,27 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { #ifdef DL_KERNELS +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(op_ptrs); } +#endif +#ifdef CK_ENABLE_FP16 else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(op_ptrs); } +#endif +#ifdef CK_ENABLE_BF16 else if constexpr(is_same_v && is_same_v && is_same_v) @@ -511,6 +557,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { @@ -529,6 +577,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { @@ -539,6 +589,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) @@ -550,10 +602,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { @@ -564,6 +618,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { @@ -574,6 +630,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) @@ -585,6 +643,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { @@ -602,6 +662,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { @@ -612,6 +674,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) @@ -623,10 +687,12 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { @@ -637,6 +703,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { @@ -647,6 +715,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) @@ -658,6 +728,7 @@ struct DeviceOperationInstanceFactory>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_INT8 void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 // grouped conv2d forward, GNHWC/GKYXC/GNHWK void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( std::vector>>& instances); - +#endif +#ifdef DL_KERNELS +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( std::vector>>& instances); - +#endif +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances); + +#ifdef DL_KERNELS void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( std::vector>>& instances); - +#endif +#endif +#ifdef CK_ENABLE_INT8 void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( std::vector>>& instances); +#endif +#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances( std::vector>>& instances); +#endif // grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 // grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_INT8 void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_BF16 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP16 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_INT8 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( std::vector>>& instances); +#endif template && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); } - else if constexpr(is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); } - else if constexpr(is_same_v && - is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(op_ptrs); } - else if constexpr(is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(op_ptrs); } +#endif } else if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); +#ifdef DL_KERNELS add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); +#endif } - else if constexpr(is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); +#ifdef DL_KERNELS add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); +#endif add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); } - else if constexpr(is_same_v && - is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); } +#endif +#ifdef CK_ENABLE_INT8 else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); } +#endif } else if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); +#ifdef DL_KERNELS add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); +#endif } - else if constexpr(is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); +#ifdef DL_KERNELS add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); +#endif } - else if constexpr(is_same_v && - is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_BDF16 + if constexpr(is_same_v && + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); } +#endif } else if constexpr(NumDimSpatial == 3 && is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs); } - else if constexpr(is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); } - else if constexpr(is_same_v && - is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(op_ptrs); } - else if constexpr(is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs); } +#endif } else if constexpr(NumDimSpatial == 3 && is_same_v && is_same_v && is_same_v) { +#ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); } - else if constexpr(is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); } - else if constexpr(is_same_v && - is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); } - else if constexpr(is_same_v && is_same_v && - is_same_v) +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs); } +#endif } return op_ptrs; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp index 9930b1a6f..4c22b047f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance.hpp @@ -2,13 +2,20 @@ // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once - #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp" @@ -18,39 +25,10 @@ #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp" -#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp" @@ -60,17 +38,38 @@ #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp" +#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp" #include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp" diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 1d54a141b..a0478c9f0 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -1,9 +1,57 @@ function(add_instance_library INSTANCE_NAME) message("adding instance ${INSTANCE_NAME}") - add_library(${INSTANCE_NAME} OBJECT ${ARGN}) - target_compile_features(${INSTANCE_NAME} PUBLIC) - set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) - clang_tidy_check(${INSTANCE_NAME}) + set(result 1) + if(DEFINED DTYPES) + foreach(source IN LISTS ARGN) + set(test 0) + foreach(type IN LISTS DTYPES) + if(type MATCHES "fp16") + set(type1 "_f16") + elseif(type MATCHES "fp32") + set(type1 "_f32") + elseif(type MATCHES "fp8") + set(type1 "_f8") + elseif(type MATCHES "bf16") + set(type1 "_b16") + elseif(type MATCHES "fp64") + set(type1 "_f64") + elseif(type MATCHES "int8") + set(type1 "_i8") + endif() + #make an exception for reduction kernels + if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance") + #if filename matches any selected type, exit type loop and do no exclude the file from the list + set(test 0) + break() + elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR + source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND + NOT(source MATCHES type OR source MATCHES type1)) + #if filename contains a type which doesn't match any selected type, mark it for removal + set(test 1) + endif() + endforeach() + if(test EQUAL 1) + message("removing instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + endif() + foreach(source IN LISTS ARGN) + if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + message("removing dl instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + #only continue if there are some source files left on the list + if(ARGN) + add_library(${INSTANCE_NAME} OBJECT ${ARGN}) + target_compile_features(${INSTANCE_NAME} PUBLIC) + set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) + clang_tidy_check(${INSTANCE_NAME}) + set(result 0) + endif() + #message("add_instance_library returns ${result}") + return(PROPAGATE result) endfunction(add_instance_library INSTANCE_NAME) @@ -15,33 +63,49 @@ IF(IS_DIRECTORY "${subdir_path}") set(cmake_instance) file(READ "${subdir_path}/CMakeLists.txt" cmake_instance) set(add_inst 0) - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8") - #message("fp8 instance found!") + if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8") + message("fp8 instance found!") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16") - #message("fp16 instance found!") + if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16") + message("fp16 instance found!") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32") - #message("fp32 instance found!") + if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32") + message("fp32 instance found!") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64") - #message("fp64 instance found!") + if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64") + message("fp64 instance found!") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16") - #message("bf16 instance found!") + if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16") + message("bf16 instance found!") set(add_inst 1) endif() - if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8") - #message("int8 instance found!") + if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8") + message("int8 instance found!") set(add_inst 1) endif() - if(NOT "${cmake_instance}" MATCHES "DTYPES" OR NOT DEFINED DTYPES) - #message("instance should be built for all types!") - set(add_inst 1) + if(NOT "${cmake_instance}" MATCHES "_fp8" OR + NOT "${cmake_instance}" MATCHES "_f8" OR + NOT "${cmake_instance}" MATCHES "_fp16" OR + NOT "${cmake_instance}" MATCHES "_f16" OR + NOT "${cmake_instance}" MATCHES "_fp32" OR + NOT "${cmake_instance}" MATCHES "_f32" OR + NOT "${cmake_instance}" MATCHES "_fp64" OR + NOT "${cmake_instance}" MATCHES "_f64" OR + NOT "${cmake_instance}" MATCHES "_bf16" OR + NOT "${cmake_instance}" MATCHES "_int8" OR + NOT "${cmake_instance}" MATCHES "_i8" OR + NOT "${cmake_instance}" MATCHES "_int4" OR + NOT DEFINED DTYPES) + message("instance should be built for all types!") + set(add_inst 1) + endif() + if("${cmake_instance}" MATCHES "quantization" AND DEFINED DTYPES AND NOT DTYPES MATCHES "int8") + message("quantization instances will not be built!") + set(add_inst 0) endif() if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS) message("Found only dl instances, but DL_KERNELS is not set. Skipping.") diff --git a/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/CMakeLists.txt index ec079e3ba..084714b70 100644 --- a/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/avg_pool3d_bwd/CMakeLists.txt @@ -1,11 +1,5 @@ set(DEVICE_AVGPOOL_BWD_INSTANCES) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f32_instance.cpp) -endif() +list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp + device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp + device_avg_pool3d_bwd_ndhwc_f32_instance.cpp) add_instance_library(device_avg_pool3d_bwd_instance ${DEVICE_AVGPOOL_BWD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt index 5b342595d..69b6ddc75 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt @@ -1,26 +1,18 @@ set(BATCHED_GEMM_INSTANCES) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp +list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp - device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp + device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp + device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp - device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp + device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp + device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp - device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp) -endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp + device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp + device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp) -endif() add_instance_library(device_batched_gemm_instance ${BATCHED_GEMM_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt index 6710035ec..d0e9b265a 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt @@ -1,6 +1,4 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_batched_gemm_add_relu_gemm_add_instance device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp ) -endif() \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt index b0f37e68f..cd9c95c06 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt @@ -1,5 +1,4 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_batched_gemm_bias_permute_instance device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp ) -endif() + diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt index cdb1a5390..865a31e79 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt @@ -1,6 +1,4 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_batched_gemm_gemm_instance device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp ) -endif() diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt index 444c93b11..b874bc50e 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_multi_d/CMakeLists.txt @@ -1,25 +1,21 @@ # ONLY DL_KERNELS -if(DL_KERNELS) - set(BATCHED_GEMM_MULTID_INSTANCES) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp) - endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp) - list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp) - endif() - add_instance_library(device_batched_gemm_multi_d_instance ${BATCHED_GEMM_MULTID_INSTANCES}) -endif() +set(BATCHED_GEMM_MULTID_INSTANCES) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp) + +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp) +list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp) + +add_instance_library(device_batched_gemm_multi_d_instance ${BATCHED_GEMM_MULTID_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt index 728e35fc3..28226faba 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt @@ -1,8 +1,6 @@ -if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) add_instance_library(device_batched_gemm_reduce_instance device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp ) -endif() diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt index 5ac55655d..6244477e1 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt @@ -1,5 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_batched_gemm_softmax_gemm_instance device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp ) -endif() diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt index d5110e450..3fd4e0344 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,11 +1,7 @@ set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp) - list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp) - list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp) -endif() +list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES + device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp + device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp + device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp + device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp) add_instance_library(device_batched_gemm_softmax_gemm_permute_instance ${DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES}) - diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt index 1db6985f6..ee9c419e1 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -1,17 +1,14 @@ set(DEVICE_CONTRACTION_BILINEAR_INSTANCES) -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - #float - list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp) -endif() -if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - #double - list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp) -endif() -add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES}) +#float +list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp) + +#double +list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp) +add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt index aed8bef2a..673fae337 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ -1,17 +1,15 @@ set(DEVICE_CONTRACTION_SCALE_INSTANCES) -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - #float - list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp) -endif() -if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - #double - list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp) -endif() +#float +list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp) + +#double +list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp + device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp) + add_instance_library(device_contraction_scale_instance ${DEVICE_CONTRACTION_SCALE_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt index 54ef9cc7a..49dfc01fd 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt @@ -1,23 +1,10 @@ set(CONV2D_BWD_DATA_INSTANCES) -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp) - if(DL_KERNELS) - list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp) - endif() -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp) -endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp) - if(DL_KERNELS) - list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp) - endif() -endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp) - if(DL_KERNELS) - list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp) - endif() -endif() +list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp + device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp + device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp + device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp + device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp) + add_instance_library(device_conv2d_bwd_data_instance ${CONV2D_BWD_DATA_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt index 96ecc9565..ba0ca3251 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -1,16 +1,8 @@ set(DEVICE_CONV2D_FWD_INSTANCES) -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp) -endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp) - list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp) -endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp) -endif() - +list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp + device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp) + add_instance_library(device_conv2d_fwd_instance ${DEVICE_CONV2D_FWD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt index ea53b82be..0c7cc2cd3 100644 --- a/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/elementwise_normalization/CMakeLists.txt @@ -1,5 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_elementwise_normalization_instance device_elementwise_normalization_f16_instance.cpp ) -endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 6c76ee843..22cf1a462 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -1,113 +1,99 @@ set(GEMM_INSTANCES) -if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp) - if(DL_KERNELS) - list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp) - endif() -endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - if(DL_KERNELS) - list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp) - endif() - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp) -endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - if(DL_KERNELS) - list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp) - endif() - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp) - list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp) -endif() +list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp + device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp + device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp + device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp) -add_instance_library(device_gemm_instance ${GEMM_INSTANCES}) +list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp + device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp + device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp + device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp + device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp + device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp + device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp + device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp) + +list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp + device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp + device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp + device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp + device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp + device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp + device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp + device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp + device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp + device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp + device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp + device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp + device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp + device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp + device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp + device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp + device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp + device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp + device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp) + +list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp + device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp + device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp + device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp + device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp + device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp + device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp + device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp) +list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - set(ENABLE_PIPELINE_V2_OPT OFF) +add_instance_library(device_gemm_instance ${GEMM_INSTANCES}) + +set(ENABLE_PIPELINE_V2_OPT OFF) - if (ENABLE_PIPELINE_V2_OPT) +if (ENABLE_PIPELINE_V2_OPT) set(MAX_ILP_OPTS -mllvm -amdgpu-enable-max-ilp-scheduling-strategy @@ -137,5 +123,5 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) set_source_files_properties(device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES COMPILE_OPTIONS "${MAX_ILP_OPTS}" COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}") - endif(ENABLE_PIPELINE_V2_OPT) -endif(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +endif(ENABLE_PIPELINE_V2_OPT) + diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt index 9028829fe..bbf81a5fa 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt @@ -1,8 +1,6 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_gemm_add_add_fastgelu_instance device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp ) -endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt index 108596680..0beb10e37 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt @@ -1,8 +1,6 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_gemm_add_fastgelu_instance device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp ) -endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt index 6079f9013..97693a256 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt @@ -1,8 +1,6 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_gemm_add_relu_add_layernorm_instance device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp ) -endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt index aef8fe86d..426edeed7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt @@ -1,4 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_gemm_bilinear_instance device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp @@ -9,4 +8,3 @@ add_instance_library(device_gemm_bilinear_instance device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp ) -endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt index 772373dcb..17d27ab15 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt @@ -1,8 +1,6 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_gemm_fastgelu_instance device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp ) -endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt index 36bd6a4aa..6cbd7528e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt @@ -1,13 +1,6 @@ set(GEMM_MULTIPLY_ADD_INSTANCES) - -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp) - list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp) -endif() - -if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) - list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp) - list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp) -endif() - +list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp + device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp + device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp + device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp) add_instance_library(device_gemm_multiply_add_instance ${GEMM_MULTIPLY_ADD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index 043b28a1b..837e040ea 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -1,28 +1,20 @@ set(GEMM_SPLITK_INSTANCES) -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp) -endif() - -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp) -endif() - -if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp) - list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp) -endif() +list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp + device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp + device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp) add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt index d9baf3f06..8dd0112a6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt @@ -1,4 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_gemm_streamk_instance # device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp # device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp @@ -9,4 +8,3 @@ add_instance_library(device_gemm_streamk_instance # device_gemm_xdl_streamk_f16_f16_f16_km_kn_mn_instance.cpp # device_gemm_xdl_streamk_f16_f16_f16_km_nk_mn_instance.cpp ) -endif() diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index f1553e3d5..b973b70aa 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -1,4 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -9,4 +8,3 @@ add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp ) -endif() diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt index a45bf3639..648f2146c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt @@ -1,8 +1,6 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_instance_library(device_grouped_gemm_fastgelu_instance device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp ) -endif() diff --git a/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt index 820a63480..d12864398 100644 --- a/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/max_pool_bwd/CMakeLists.txt @@ -1,11 +1,5 @@ set(DEVICE_MAXPOOL_BWD_INSTANCES) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_bf16_instance.cpp) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f32_instance.cpp) -endif() +list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp + device_max_pool_bwd_bf16_instance.cpp + device_max_pool_bwd_f32_instance.cpp) add_instance_library(device_max_pool_bwd_instance ${DEVICE_MAXPOOL_BWD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt index 4892a9c18..19818a523 100644 --- a/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/normalization/CMakeLists.txt @@ -1,15 +1,14 @@ set(DEVICE_NORMALIZATION_INSTANCES) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f16_instance.cpp + +list(APPEND DEVICE_NORMALIZATION_INSTANCES + device_layernorm2d_f16_instance.cpp device_layernorm4d_f16_instance.cpp device_groupnorm_f16_instance.cpp device_groupnorm_swish_f16_instance.cpp - device_groupnorm_swish_f16_f32_f32_f16_instance.cpp) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f32_instance.cpp + device_groupnorm_swish_f16_f32_f32_f16_instance.cpp + device_layernorm2d_f32_instance.cpp device_layernorm4d_f32_instance.cpp device_groupnorm_f32_instance.cpp device_groupnorm_swish_f32_instance.cpp) -endif() + add_instance_library(device_normalization_instance ${DEVICE_NORMALIZATION_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt index 63bbe7546..533ce89df 100644 --- a/library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/pool3d_fwd/CMakeLists.txt @@ -1,14 +1,8 @@ set(DEVICE_POOL3D_FWD_INSTANCES) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp - device_max_pool3d_fwd_ndhwc_f16_instance.cpp) -endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp - device_max_pool3d_fwd_ndhwc_bf16_instance.cpp) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f32_instance.cpp - device_max_pool3d_fwd_ndhwc_f32_instance.cpp) -endif() +list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp + device_max_pool3d_fwd_ndhwc_f16_instance.cpp + device_avg_pool3d_fwd_ndhwc_f32_instance.cpp + device_max_pool3d_fwd_ndhwc_f32_instance.cpp + device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp + device_max_pool3d_fwd_ndhwc_bf16_instance.cpp) add_instance_library(device_pool3d_fwd_instance ${DEVICE_POOL3D_FWD_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt index 00eb6ff1c..c22a6e9e9 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt @@ -1,5 +1,3 @@ -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - set(CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp) set(CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp) set(CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp) @@ -10,17 +8,16 @@ set(GEMM_QUANT_SRC gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp ) -if(DL_KERNELS) - list(APPEND CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp) - list(APPEND CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp) - list(APPEND CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp) - list(APPEND CONV2D_BIAS_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp) - list(APPEND GEMM_QUANT_SRC - gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp - gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp - gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp - gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp) -endif() + +list(APPEND CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp) +list(APPEND CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp) +list(APPEND CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp) +list(APPEND CONV2D_BIAS_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp) +list(APPEND GEMM_QUANT_SRC + gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp + gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp + gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp + gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp) add_instance_library(device_quantization_instance ${CONV2D_PERLAYER_QUANT_SRC} @@ -29,4 +26,3 @@ add_instance_library(device_quantization_instance ${CONV2D_BIAS_PERCHANNEL_QUANT_SRC} ${GEMM_QUANT_SRC} ) -endif() \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt index ba0197477..dbe376411 100644 --- a/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/softmax/CMakeLists.txt @@ -1,20 +1,17 @@ set(DEVICE_SOFTMAX_INSTANCES) -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f16_f16_instance_rank3_reduce1.cpp +list(APPEND DEVICE_SOFTMAX_INSTANCES + device_softmax_f16_f16_instance_rank3_reduce1.cpp device_softmax_f16_f16_instance_rank3_reduce2.cpp device_softmax_f16_f16_instance_rank3_reduce3.cpp device_softmax_f16_f16_instance_rank4_reduce1.cpp device_softmax_f16_f16_instance_rank4_reduce2.cpp device_softmax_f16_f16_instance_rank4_reduce3.cpp - device_softmax_f16_f16_instance_rank4_reduce4.cpp) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f32_f32_instance_rank3_reduce1.cpp + device_softmax_f16_f16_instance_rank4_reduce4.cpp + device_softmax_f32_f32_instance_rank3_reduce1.cpp device_softmax_f32_f32_instance_rank3_reduce2.cpp device_softmax_f32_f32_instance_rank3_reduce3.cpp device_softmax_f32_f32_instance_rank4_reduce1.cpp device_softmax_f32_f32_instance_rank4_reduce2.cpp device_softmax_f32_f32_instance_rank4_reduce3.cpp device_softmax_f32_f32_instance_rank4_reduce4.cpp) -endif() add_instance_library(device_softmax_instance ${DEVICE_SOFTMAX_INSTANCES}) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 8fddd6085..368965aa8 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -9,26 +9,121 @@ add_custom_target(tests) function(add_test_executable TEST_NAME) message("adding test ${TEST_NAME}") - add_executable(${TEST_NAME} ${ARGN}) - add_test(NAME ${TEST_NAME} COMMAND $) - add_dependencies(tests ${TEST_NAME}) - add_dependencies(check ${TEST_NAME}) - rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) + set(result 1) + if(DEFINED DTYPES) + foreach(source IN LISTS ARGN) + set(test 0) + foreach(type IN LISTS DTYPES) + if(type MATCHES "fp16") + set(type1 "_f16") + elseif(type MATCHES "fp32") + set(type1 "_f32") + elseif(type MATCHES "fp8") + set(type1 "_f8") + elseif(type MATCHES "bf16") + set(type1 "_b16") + elseif(type MATCHES "fp64") + set(type1 "_f64") + elseif(type MATCHES "int8") + set(type1 "_i8") + endif() + if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") + #if filename matches any selected type, exit type loop and do no exclude the file from the list + set(test 0) + break() + elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR + source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND + NOT(source MATCHES type OR source MATCHES type1)) + #if filename contains a type which doesn't match any selected type, mark it for removal + set(test 1) + endif() + endforeach() + if(test EQUAL 1) + message("removing test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + endif() + foreach(source IN LISTS ARGN) + if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + message("removing dl test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + + #only continue if there are some source files left on the list + if(ARGN) + add_executable(${TEST_NAME} ${ARGN}) + add_test(NAME ${TEST_NAME} COMMAND $) + add_dependencies(tests ${TEST_NAME}) + add_dependencies(check ${TEST_NAME}) + rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) + set(result 0) + endif() + #message("add_test returns ${result}") + return(PROPAGATE result) endfunction(add_test_executable TEST_NAME) include(GoogleTest) function(add_gtest_executable TEST_NAME) message("adding gtest ${TEST_NAME}") - add_executable(${TEST_NAME} ${ARGN}) - add_dependencies(tests ${TEST_NAME}) - add_dependencies(check ${TEST_NAME}) + set(result 1) + if(DEFINED DTYPES) + foreach(source IN LISTS ARGN) + set(test 0) + foreach(type IN LISTS DTYPES) + if(type MATCHES "fp16") + set(type1 "_f16") + elseif(type MATCHES "fp32") + set(type1 "_f32") + elseif(type MATCHES "fp8") + set(type1 "_f8") + elseif(type MATCHES "bf16") + set(type1 "_b16") + elseif(type MATCHES "fp64") + set(type1 "_f64") + elseif(type MATCHES "int8") + set(type1 "_i8") + endif() + if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}") + #if filename matches any selected type, exit type loop and do no exclude the file from the list + set(test 0) + break() + elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR + source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND + NOT(source MATCHES type OR source MATCHES type1)) + #if filename contains a type which doesn't match any selected type, mark it for removal + set(test 1) + endif() + endforeach() + if(test EQUAL 1) + message("removing gtest ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + endif() + foreach(source IN LISTS ARGN) + if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + message("removing dl test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + #only continue if there are some source files left on the list + if(ARGN) + add_executable(${TEST_NAME} ${ARGN}) + add_dependencies(tests ${TEST_NAME}) + add_dependencies(check ${TEST_NAME}) - # suppress gtest warnings - target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef) - target_link_libraries(${TEST_NAME} PRIVATE gtest_main) - add_test(NAME ${TEST_NAME} COMMAND $) - rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) + # suppress gtest warnings + target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(${TEST_NAME} PRIVATE gtest_main) + add_test(NAME ${TEST_NAME} COMMAND $) + rocm_install(TARGETS ${TEST_NAME} COMPONENT tests) + set(result 0) + endif() + #message("add_gtest returns ${result}") + return(PROPAGATE result) endfunction(add_gtest_executable TEST_NAME) add_subdirectory(magic_number_division) diff --git a/test/batched_gemm/CMakeLists.txt b/test/batched_gemm/CMakeLists.txt index 8a3269e90..d169c135c 100644 --- a/test/batched_gemm/CMakeLists.txt +++ b/test/batched_gemm/CMakeLists.txt @@ -2,25 +2,21 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) - target_link_libraries(test_batched_gemm_fp16 PRIVATE utility) - target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance) + add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp) + if(result EQUAL 0) + target_link_libraries(test_batched_gemm_fp16 PRIVATE utility device_batched_gemm_instance) endif() - if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp) - target_link_libraries(test_batched_gemm_fp32 PRIVATE utility) - target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance) + add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp) + if(result EQUAL 0) + target_link_libraries(test_batched_gemm_fp32 PRIVATE utility device_batched_gemm_instance) endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp) - target_link_libraries(test_batched_gemm_bf16 PRIVATE utility) - target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance) + add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp) + if(result EQUAL 0) + target_link_libraries(test_batched_gemm_bf16 PRIVATE utility device_batched_gemm_instance) endif() - if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) - add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp) - target_link_libraries(test_batched_gemm_int8 PRIVATE utility) - target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance) + add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp) + if(result EQUAL 0) + target_link_libraries(test_batched_gemm_int8 PRIVATE utility device_batched_gemm_instance) endif() set(target 1) endif() diff --git a/test/batched_gemm_gemm/CMakeLists.txt b/test/batched_gemm_gemm/CMakeLists.txt index 404e74f3b..03f1d3a4e 100644 --- a/test/batched_gemm_gemm/CMakeLists.txt +++ b/test/batched_gemm_gemm/CMakeLists.txt @@ -2,12 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_custom_target(test_batched_gemm_gemm) - add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) + add_custom_target(test_batched_gemm_gemm) + add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) + if(result EQUAL 0) target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) set(target 1) - endif() + endif() endif() endforeach() \ No newline at end of file diff --git a/test/batched_gemm_multi_d/CMakeLists.txt b/test/batched_gemm_multi_d/CMakeLists.txt index 825f0dd23..d5e4c4fbe 100644 --- a/test/batched_gemm_multi_d/CMakeLists.txt +++ b/test/batched_gemm_multi_d/CMakeLists.txt @@ -1,4 +1,4 @@ -if(DL_KERNELS) - add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp) +add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d_dl.cpp) +if(result EQUAL 0) target_link_libraries(test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance) endif() diff --git a/test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp b/test/batched_gemm_multi_d/test_batched_gemm_multi_d_dl.cpp similarity index 100% rename from test/batched_gemm_multi_d/test_batched_gemm_multi_d.cpp rename to test/batched_gemm_multi_d/test_batched_gemm_multi_d_dl.cpp diff --git a/test/batched_gemm_reduce/CMakeLists.txt b/test/batched_gemm_reduce/CMakeLists.txt index af95a50ea..32c6ee85d 100644 --- a/test/batched_gemm_reduce/CMakeLists.txt +++ b/test/batched_gemm_reduce/CMakeLists.txt @@ -2,10 +2,9 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) - target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility) - target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance) + add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) + if(result EQUAL 0) + target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance) set(target 1) endif() endif() diff --git a/test/batched_gemm_softmax_gemm/CMakeLists.txt b/test/batched_gemm_softmax_gemm/CMakeLists.txt index c49175a2e..c011a6a3c 100644 --- a/test/batched_gemm_softmax_gemm/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm/CMakeLists.txt @@ -2,12 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_custom_target(test_batched_gemm_softmax_gemm) - add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) - target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) - add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) - set(target 1) - endif() + add_custom_target(test_batched_gemm_softmax_gemm) + add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) + if(result EQUAL 0) + target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) + add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) + set(target 1) + endif() endif() endforeach() \ No newline at end of file diff --git a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt index ae28daa80..3164863ee 100644 --- a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -2,25 +2,28 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_custom_target(test_batched_gemm_softmax_gemm_permute) - endif() - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp) - add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp) - target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) - endif() - if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) - add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp) - add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp) - target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) - endif() + add_custom_target(test_batched_gemm_softmax_gemm_permute) + add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp) + if(result EQUAL 0) + target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) + endif() + add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp) + if(result EQUAL 0) + target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) + endif() + + add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp) + if(result EQUAL 0) + target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) + endif() + add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp) + if(result EQUAL 0) + target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) + endif() set(target 1) endif() endforeach() \ No newline at end of file diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index baf017455..b1606d2a7 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -1,14 +1,15 @@ if (USE_BITINT_EXTENSION_INT4) add_gtest_executable(test_int4 int4.cpp) - target_link_libraries(test_int4 PRIVATE utility) + if(result EQUAL 0) + target_link_libraries(test_int4 PRIVATE utility) + endif() endif() -if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) - add_gtest_executable(test_f8 f8.cpp) - target_link_libraries(test_f8 PRIVATE utility) +add_gtest_executable(test_fp8 fp8.cpp) +if(result EQUAL 0) + target_link_libraries(test_fp8 PRIVATE utility) endif() - -if(DTYPES MATCHES "bf8" OR NOT DEFINED DTYPES) - add_gtest_executable(test_bf8 bf8.cpp) +add_gtest_executable(test_bf8 bf8.cpp) +if(result EQUAL 0) target_link_libraries(test_bf8 PRIVATE utility) endif() diff --git a/test/data_type/f8.cpp b/test/data_type/fp8.cpp similarity index 100% rename from test/data_type/f8.cpp rename to test/data_type/fp8.cpp diff --git a/test/elementwise_normalization/CMakeLists.txt b/test/elementwise_normalization/CMakeLists.txt index 74a3e4999..aed67901b 100644 --- a/test/elementwise_normalization/CMakeLists.txt +++ b/test/elementwise_normalization/CMakeLists.txt @@ -1,6 +1,6 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_custom_target(test_elementwise_normalization) - add_gtest_executable(test_elementwise_layernorm_fp16 test_elementwise_layernorm_fp16.cpp) +add_custom_target(test_elementwise_normalization) +add_gtest_executable(test_elementwise_layernorm_fp16 test_elementwise_layernorm_fp16.cpp) +if(result EQUAL 0) target_link_libraries(test_elementwise_layernorm_fp16 PRIVATE utility device_elementwise_normalization_instance) add_dependencies(test_elementwise_normalization test_elementwise_layernorm_fp16) endif() \ No newline at end of file diff --git a/test/gemm/CMakeLists.txt b/test/gemm/CMakeLists.txt index 7a8836bfe..f88a13404 100644 --- a/test/gemm/CMakeLists.txt +++ b/test/gemm/CMakeLists.txt @@ -1,30 +1,28 @@ -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) add_test_executable(test_gemm_fp32 gemm_fp32.cpp) -target_link_libraries(test_gemm_fp32 PRIVATE utility) -target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance) +if(result EQUAL 0) + target_link_libraries(test_gemm_fp32 PRIVATE utility device_gemm_instance) endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_test_executable(test_gemm_fp16 gemm_fp16.cpp) -target_link_libraries(test_gemm_fp16 PRIVATE utility) -target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance) -add_library(gemm_standalone_xdl_fp16_instances STATIC +if(result EQUAL 0) + target_link_libraries(test_gemm_fp16 PRIVATE utility device_gemm_instance) + add_library(gemm_standalone_xdl_fp16_instances STATIC instance/gemm_f16_nn_instance.cpp instance/gemm_f16_nt_instance.cpp instance/gemm_f16_tn_instance.cpp instance/gemm_wavelet_f16_tn_instance.cpp instance/gemm_f16_tt_instance.cpp -) + ) +endif() add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp) -target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility) -target_include_directories(test_gemm_standalone_xdl_fp16 PRIVATE instance/) +if(result EQUAL 0) + target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility) + target_include_directories(test_gemm_standalone_xdl_fp16 PRIVATE instance/) endif() -if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_test_executable(test_gemm_bf16 gemm_bf16.cpp) -target_link_libraries(test_gemm_bf16 PRIVATE utility) -target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance) +if(result EQUAL 0) + target_link_libraries(test_gemm_bf16 PRIVATE utility device_gemm_instance) endif() -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) add_test_executable(test_gemm_int8 gemm_int8.cpp) -target_link_libraries(test_gemm_int8 PRIVATE utility) -target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance) +if(result EQUAL 0) + target_link_libraries(test_gemm_int8 PRIVATE utility device_gemm_instance) endif() \ No newline at end of file diff --git a/test/gemm_layernorm/CMakeLists.txt b/test/gemm_layernorm/CMakeLists.txt index ba0a99b67..bfc4404bd 100644 --- a/test/gemm_layernorm/CMakeLists.txt +++ b/test/gemm_layernorm/CMakeLists.txt @@ -2,12 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list AND target EQUAL 0) - if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) add_custom_target(test_gemm_layernorm) add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp) - target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) - add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) - set(target 1) - endif() + if(result EQUAL 0) + target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) + add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) + set(target 1) + endif() endif() endforeach() diff --git a/test/gemm_reduce/CMakeLists.txt b/test/gemm_reduce/CMakeLists.txt index 43c8d6074..42a53c304 100644 --- a/test/gemm_reduce/CMakeLists.txt +++ b/test/gemm_reduce/CMakeLists.txt @@ -1,5 +1,4 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) - target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility) - target_link_libraries(test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance) +add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility device_gemm_reduce_instance) endif() \ No newline at end of file diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index 476d953ed..8c57b667e 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -1,4 +1,3 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -13,4 +12,3 @@ foreach(gpu IN LISTS GPU_TARGETS) set(target 1) endif() endforeach() -endif() diff --git a/test/normalization/CMakeLists.txt b/test/normalization/CMakeLists.txt index 2beda4dd7..fe4b1ad82 100644 --- a/test/normalization/CMakeLists.txt +++ b/test/normalization/CMakeLists.txt @@ -1,19 +1,21 @@ -if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_custom_target(test_normalization) -endif() -if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES) - add_gtest_executable(test_layernorm2d_fp32 test_layernorm2d_fp32.cpp) - add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp) +add_custom_target(test_normalization) +add_gtest_executable(test_layernorm2d_fp32 test_layernorm2d_fp32.cpp) +if(result EQUAL 0) target_link_libraries(test_layernorm2d_fp32 PRIVATE utility device_normalization_instance) - target_link_libraries(test_groupnorm_fp32 PRIVATE utility device_normalization_instance) add_dependencies(test_normalization test_layernorm2d_fp32) +endif() +add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp) +if(result EQUAL 0) + target_link_libraries(test_groupnorm_fp32 PRIVATE utility device_normalization_instance) add_dependencies(test_normalization test_groupnorm_fp32) endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - add_gtest_executable(test_layernorm2d_fp16 test_layernorm2d_fp16.cpp) - add_gtest_executable(test_groupnorm_fp16 test_groupnorm_fp16.cpp) +add_gtest_executable(test_layernorm2d_fp16 test_layernorm2d_fp16.cpp) +if(result EQUAL 0) target_link_libraries(test_layernorm2d_fp16 PRIVATE utility device_normalization_instance) - target_link_libraries(test_groupnorm_fp16 PRIVATE utility device_normalization_instance) add_dependencies(test_normalization test_layernorm2d_fp16) +endif() +add_gtest_executable(test_groupnorm_fp16 test_groupnorm_fp16.cpp) +if(result EQUAL 0) + target_link_libraries(test_groupnorm_fp16 PRIVATE utility device_normalization_instance) add_dependencies(test_normalization test_groupnorm_fp16) endif() diff --git a/test/reduce/CMakeLists.txt b/test/reduce/CMakeLists.txt index fb436165e..fabaf7cc5 100644 --- a/test/reduce/CMakeLists.txt +++ b/test/reduce/CMakeLists.txt @@ -1,7 +1,5 @@ add_test_executable(test_reduce_no_index reduce_no_index.cpp) add_test_executable(test_reduce_with_index reduce_with_index.cpp) -target_link_libraries(test_reduce_no_index PRIVATE utility) -target_link_libraries(test_reduce_no_index PRIVATE device_reduce_instance) -target_link_libraries(test_reduce_with_index PRIVATE utility) -target_link_libraries(test_reduce_with_index PRIVATE device_reduce_instance) +target_link_libraries(test_reduce_no_index PRIVATE utility device_reduce_instance) +target_link_libraries(test_reduce_with_index PRIVATE utility device_reduce_instance) -- GitLab From ede64ae9db8eb052329a7fa7560cc660cfd6be09 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Fri, 22 Sep 2023 10:08:45 -0500 Subject: [PATCH 026/112] Update naming (#937) --- .../gpu/gemm_splitk/CMakeLists.txt | 16 ++++++++-------- ...xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp} | 0 ...xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp} | 0 ...xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp} | 0 ...xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp} | 0 ...xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp} | 0 ...xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp} | 0 ...xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp} | 0 ...xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp} | 0 9 files changed, 8 insertions(+), 8 deletions(-) rename library/src/tensor_operation_instance/gpu/gemm_splitk/{device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp => device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp} (100%) rename library/src/tensor_operation_instance/gpu/gemm_splitk/{device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp => device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp} (100%) rename library/src/tensor_operation_instance/gpu/gemm_splitk/{device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp => device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp} (100%) rename library/src/tensor_operation_instance/gpu/gemm_splitk/{device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp => device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp} (100%) rename library/src/tensor_operation_instance/gpu/gemm_splitk/{device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp => device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp} (100%) rename library/src/tensor_operation_instance/gpu/gemm_splitk/{device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp => device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp} (100%) rename library/src/tensor_operation_instance/gpu/gemm_splitk/{device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp => device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp} (100%) rename library/src/tensor_operation_instance/gpu/gemm_splitk/{device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp => device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp} (100%) diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index 837e040ea..36c0b50e8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -8,13 +8,13 @@ list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_in device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp - device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp - device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp - device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp - device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp - device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp - device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp - device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp - device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp) + device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp) add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp rename to library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp -- GitLab From c95538325b49a9a12c761f8783b0b0f8c3161f2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Sat, 23 Sep 2023 18:56:31 +0200 Subject: [PATCH 027/112] Add 3d grouped conv fwd wmma instances (#935) * Add 3d grouped conv fwd wmma instances * Refactor fwd conv tests * Split wmma instances for each specialization * Minor stylistic fixes --- ...evice_grouped_conv2d_fwd_wmma_instance.hpp | 134 ----- .../device_grouped_conv_fwd_wmma_instance.hpp | 136 +++++ .../gpu/grouped_convolution_forward.hpp | 465 +++++++++++++++++- .../gpu/grouped_conv1d_fwd/CMakeLists.txt | 8 +- ...d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp | 0 ...1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp | 0 ...1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp | 0 ...d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp | 0 .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 40 +- ...wd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 66 --- ...fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp | 66 --- ..._fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 0 ..._fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp | 0 ..._fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 0 ..._fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 0 ...a_gnhwc_gkyxc_gnhwk_f16_1x1p0_instance.cpp | 40 ++ ...gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp | 40 ++ ...wd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 40 ++ ...ma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp | 40 ++ ...ma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instance.cpp | 40 ++ ..._gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp | 40 ++ ...fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp | 40 ++ ...mma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp | 40 ++ ...a_nhwgc_gkyxc_nhwgk_f16_1x1p0_instance.cpp | 40 ++ ...nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp | 40 ++ ...wd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 40 ++ ...ma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp | 40 ++ ...ma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instance.cpp | 40 ++ ..._nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp | 40 ++ ...fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp | 40 ++ ...mma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp | 40 ++ ...wd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp | 0 ...fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 2 +- ...fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp | 0 ...wd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 0 ...fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 0 ...fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 0 .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 38 +- ...ndhwc_gkzyxc_gndhwk_f16_1x1p0_instance.cpp | 41 ++ ...hwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp | 41 ++ ...wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp | 41 ++ ...gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp | 41 ++ ...gndhwc_gkzyxc_gndhwk_i8_1x1p0_instance.cpp | 41 ++ ...dhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp | 41 ++ ..._wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp | 41 ++ ..._gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp | 41 ++ ...dhwgc_gkzyxc_ndhwgk_f16_1x1p0_instance.cpp | 41 ++ ...wgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp | 41 ++ ...wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 41 ++ ...ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp | 41 ++ ...ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instance.cpp | 41 ++ ...hwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp | 41 ++ ..._wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp | 41 ++ ..._ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp | 41 ++ ...xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp | 0 ..._xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp | 0 ..._xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp | 0 ...xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp | 0 ...xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 0 ..._xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 0 ..._xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 0 ...xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp | 0 test/grouped_convnd_fwd/CMakeLists.txt | 2 +- .../grouped_convnd_fwd/grouped_convnd_fwd.cpp | 284 ----------- .../test_grouped_convnd_fwd.cpp | 133 +++++ 65 files changed, 2092 insertions(+), 578 deletions(-) delete mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp rename library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/{ => xdl}/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/{ => xdl}/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/{ => xdl}/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/{ => xdl}/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp (100%) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/{ => dl}/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/{ => dl}/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/{ => dl}/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/{ => dl}/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp (100%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/{ => xdl}/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/{ => xdl}/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp (98%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/{ => xdl}/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/{ => xdl}/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/{ => xdl}/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/{ => xdl}/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp (100%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/{ => xdl}/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/{ => xdl}/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/{ => xdl}/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/{ => xdl}/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/{ => xdl}/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/{ => xdl}/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/{ => xdl}/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp (100%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/{ => xdl}/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp (100%) delete mode 100644 test/grouped_convnd_fwd/grouped_convnd_fwd.cpp create mode 100644 test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp deleted file mode 100644 index c9cf0f8e1..000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp +++ /dev/null @@ -1,134 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using BF16 = ck::bhalf_t; -using F16 = ck::half_t; -using F32 = float; -using I8 = int8_t; -using I32 = int32_t; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using NHWGC = ck::tensor_layout::convolution::NHWGC; -using GNHWC = ck::tensor_layout::convolution::GNHWC; - -using GKYXC = ck::tensor_layout::convolution::GKYXC; - -using NHWGK = ck::tensor_layout::convolution::NHWGK; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto ConvFwdOddC = - ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -template -using device_grouped_conv2d_fwd_wmma_f16_instances = std::tuple< - // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // blocksize=256 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 256, 4, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 4, 8, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - // blocksize=128 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 4, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 4, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 32, 256, 4, 8, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 256, 32, 4, 8, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - // blocksize=64 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 4, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 128, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - // blocksize=32 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 64, 4, 8, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 64, 16, 4, 8, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 32, 32, 4, 8, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 16, 4, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> - // clang-format on - >; - -template -using device_grouped_conv2d_fwd_wmma_i8_instances = std::tuple< - // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // blocksize=256 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 256, 4, 16, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 4, 16, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 8, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, - // blocksize=128 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 8, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 4, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 8, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 4, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 8, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 32, 256, 4, 16, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 256, 32, 4, 16, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - // blocksize=64 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 4, 16, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 4, 16, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 32, 8, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 128, 4, 16, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, - // blocksize=32 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 64, 4, 16, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 64, 16, 4, 16, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 32, 32, 4, 16, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8> - // clang-format on - >; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp new file mode 100644 index 000000000..f92539783 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using I8 = int8_t; +using I32 = int32_t; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using device_grouped_conv_fwd_wmma_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + // blocksize=256 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // blocksize=128 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // blocksize=64 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + // blocksize=32 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +template +using device_grouped_conv_fwd_wmma_i8_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //generic instance + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, 1, 1, S<1, 32, 1, 4>, 1>, + // blocksize=256 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + // blocksize=128 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + // blocksize=64 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + // blocksize=32 + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 7fc88c33a..e6dbd349d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -170,6 +170,48 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( + std::vector>>& instances); + #ifdef DL_KERNELS void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( + std::vector>>& instances); #endif #if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances( @@ -234,6 +318,62 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( PassThrough>>>& instances); #endif #ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( std::vector>>& instances); #endif +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( + std::vector>>& instances); +#endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( @@ -323,6 +576,62 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_BF16 // grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK @@ -354,6 +663,62 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( @@ -384,6 +749,62 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( PassThrough, PassThrough, PassThrough>>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( + std::vector>>& instances); #endif template ) { add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(op_ptrs); } #endif } @@ -516,14 +943,28 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); } +#endif +#ifdef CK_ENABLE_INT8 + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(op_ptrs); + } #endif } else if constexpr(NumDimSpatial == 3 && is_same_v && @@ -541,6 +982,12 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -555,6 +1002,11 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(op_ptrs); } #endif } @@ -573,6 +1025,12 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -587,6 +1045,11 @@ struct DeviceOperationInstanceFactory) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances(op_ptrs); } #endif } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt index 1d90593e3..f51a484bb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt @@ -1,6 +1,6 @@ add_instance_library(device_grouped_conv1d_fwd_instance - device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp - device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp - device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp - device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp + xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp + xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp + xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp + xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 6b0ad99d6..1542d611f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -1,21 +1,37 @@ add_instance_library(device_grouped_conv2d_fwd_instance #xdl # GNHWC, GKYXC, GNHWK - device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp - device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp - device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp # NHWGC, GKYXC, NHWGK - device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp - device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp - device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp + xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp #dl # GNHWC, GKYXC, GNHWK - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp + dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + # NHWGC, GKYXC, NHWGK + dl/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp + dl/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp # WMMA - device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp - device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp + # GNHWC, GKYXC, GNHWK + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp # NHWGC, GKYXC, NHWGK - device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp - device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp deleted file mode 100644 index 3904592fc..000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[g, n, hi ,wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv2d_fwd_wmma_f16_instances{}); - - add_device_operation_instances(instances, - device_grouped_conv2d_fwd_wmma_f16_instances{}); - - add_device_operation_instances(instances, - device_grouped_conv2d_fwd_wmma_f16_instances{}); - - add_device_operation_instances(instances, - device_grouped_conv2d_fwd_wmma_f16_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp deleted file mode 100644 index 0cded93d6..000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[g, n, hi ,wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv2d_fwd_wmma_i8_instances{}); - - add_device_operation_instances(instances, - device_grouped_conv2d_fwd_wmma_i8_instances{}); - - add_device_operation_instances(instances, - device_grouped_conv2d_fwd_wmma_i8_instances{}); - - add_device_operation_instances(instances, - device_grouped_conv2d_fwd_wmma_i8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/dl/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/dl/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/dl/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/dl/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instance.cpp new file mode 100644 index 000000000..2aa6608c8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + Empty_Tuple, + PassThrough, + ConvFwd1x1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp new file mode 100644 index 000000000..f928d525a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + Empty_Tuple, + PassThrough, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp new file mode 100644 index 000000000..94cf43999 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + Empty_Tuple, + PassThrough, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp new file mode 100644 index 000000000..f519a1936 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instance.cpp new file mode 100644 index 000000000..926b7227c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + Empty_Tuple, + PassThrough, + ConvFwd1x1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp new file mode 100644 index 000000000..b72c4cf5d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + Empty_Tuple, + PassThrough, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp new file mode 100644 index 000000000..108602e13 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + Empty_Tuple, + PassThrough, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp new file mode 100644 index 000000000..b50354287 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<2, + GNHWC, + GKYXC, + Empty_Tuple, + GNHWK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instance.cpp new file mode 100644 index 000000000..8316debe5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + Empty_Tuple, + PassThrough, + ConvFwd1x1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp new file mode 100644 index 000000000..9b90db14c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + Empty_Tuple, + PassThrough, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 000000000..cd69e5aa2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + Empty_Tuple, + PassThrough, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp new file mode 100644 index 000000000..78d7711a4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instance.cpp new file mode 100644 index 000000000..193d35dc6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + Empty_Tuple, + PassThrough, + ConvFwd1x1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp new file mode 100644 index 000000000..c3a336127 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + Empty_Tuple, + PassThrough, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp new file mode 100644 index 000000000..3e57cd496 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + Empty_Tuple, + PassThrough, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp new file mode 100644 index 000000000..75f8ff032 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp similarity index 98% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index debdb05b4..cbdce0009 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -8,7 +8,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -// Compilation parameters for in[g, n, hi ,wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] +// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + Empty_Tuple, + PassThrough, + ConvFwd1x1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp new file mode 100644 index 000000000..05e42a6cd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, +// wo, k] +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + Empty_Tuple, + PassThrough, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp new file mode 100644 index 000000000..85138786b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, +// wo, k] +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + Empty_Tuple, + PassThrough, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp new file mode 100644 index 000000000..ae1931afa --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, +// wo, k] +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instance.cpp new file mode 100644 index 000000000..fd2b8b2ec --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, +// wo, k] +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + Empty_Tuple, + PassThrough, + ConvFwd1x1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp new file mode 100644 index 000000000..4de33c357 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, +// wo, k] +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + Empty_Tuple, + PassThrough, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp new file mode 100644 index 000000000..9a6856de6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, +// wo, k] +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + Empty_Tuple, + PassThrough, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp new file mode 100644 index 000000000..ca40e163d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, +// wo, k] +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<3, + GNDHWC, + GKZYXC, + Empty_Tuple, + GNDHWK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instance.cpp new file mode 100644 index 000000000..4357f583b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + Empty_Tuple, + PassThrough, + ConvFwd1x1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp new file mode 100644 index 000000000..199747eb9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + Empty_Tuple, + PassThrough, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 000000000..1c3f20bd9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + Empty_Tuple, + PassThrough, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp new file mode 100644 index 000000000..a3cb03709 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_f16_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instance.cpp new file mode 100644 index 000000000..721fa653f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + Empty_Tuple, + PassThrough, + ConvFwd1x1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp new file mode 100644 index 000000000..459af039d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + Empty_Tuple, + PassThrough, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp new file mode 100644 index 000000000..e02557e1a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + Empty_Tuple, + PassThrough, + ConvFwdDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp new file mode 100644 index 000000000..e48e3675a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = out[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_wmma_i8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + Empty_Tuple, + PassThrough, + ConvFwdOddC>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp similarity index 100% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 38da88473..869bd77d3 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -1,3 +1,3 @@ -add_gtest_executable(test_grouped_convnd_fwd grouped_convnd_fwd.cpp) +add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp) target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) diff --git a/test/grouped_convnd_fwd/grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/grouped_convnd_fwd.cpp deleted file mode 100644 index c856255ea..000000000 --- a/test/grouped_convnd_fwd/grouped_convnd_fwd.cpp +++ /dev/null @@ -1,284 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include -#include - -#include "profiler/profile_grouped_conv_fwd_impl.hpp" - -class TestGroupedConvNdFwd : public ::testing::Test -{ - protected: - std::vector conv_params; -}; - -// 1d GNWC/GKXC/GNWK -TEST_F(TestGroupedConvNdFwd, GroupedConv1dFwdGNWC) -{ - conv_params.clear(); - conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); - conv_params.push_back({1, 2, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); - conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); - conv_params.push_back({1, 1, 1, 1, 32, {3}, {32}, {1}, {1}, {1}, {1}}); - conv_params.push_back({1, 1, 1, 64, 3, {3}, {32}, {1}, {1}, {1}, {1}}); - - for(auto& param : conv_params) - { - bool pass; - - // fp32 - pass = ck::profiler::profile_grouped_conv_fwd_impl<1, - ck::tensor_layout::convolution::GNWC, - ck::tensor_layout::convolution::GKXC, - ck::tensor_layout::convolution::GNWK, - float, - float, - float>(true, // do_verification - 1, // init_method - false, // do_log - false, // time_kernel - param); - - EXPECT_TRUE(pass); - - // fp16 - pass = ck::profiler::profile_grouped_conv_fwd_impl<1, - ck::tensor_layout::convolution::GNWC, - ck::tensor_layout::convolution::GKXC, - ck::tensor_layout::convolution::GNWK, - ck::half_t, - ck::half_t, - ck::half_t>(true, // do_verification - 1, // init_method - false, // do_log - false, // time_kernel - param); - - EXPECT_TRUE(pass); - - // bf16 - pass = ck::profiler::profile_grouped_conv_fwd_impl<1, - ck::tensor_layout::convolution::GNWC, - ck::tensor_layout::convolution::GKXC, - ck::tensor_layout::convolution::GNWK, - ck::bhalf_t, - ck::bhalf_t, - ck::bhalf_t>(true, // do_verification - 1, // init_method - false, // do_log - false, // time_kernel - param); - - EXPECT_TRUE(pass); - - // int8 - pass = ck::profiler::profile_grouped_conv_fwd_impl<1, - ck::tensor_layout::convolution::GNWC, - ck::tensor_layout::convolution::GKXC, - ck::tensor_layout::convolution::GNWK, - int8_t, - int8_t, - int8_t>(true, // do_verification - 1, // init_method - false, // do_log - false, // time_kernel - param); - - EXPECT_TRUE(pass); - } -} - -// 2d GNHWC/GKYXC/GNHWK -TEST_F(TestGroupedConvNdFwd, GroupedConv2dFwdGNHWC) -{ - conv_params.clear(); - conv_params.push_back({2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); - conv_params.push_back({2, 2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - conv_params.push_back({2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); - conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - - for(auto& param : conv_params) - { - bool pass; - - // fp32 - pass = ck::profiler::profile_grouped_conv_fwd_impl<2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::tensor_layout::convolution::GNHWK, - float, - float, - float>(true, // do_verification - 1, // init_method - false, // do_log - false, // time_kernel - param); - - EXPECT_TRUE(pass); - - // fp16 - pass = ck::profiler::profile_grouped_conv_fwd_impl<2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::tensor_layout::convolution::GNHWK, - ck::half_t, - ck::half_t, - ck::half_t>(true, // do_verification - 1, // init_method - false, // do_log - false, // time_kernel - param); - - EXPECT_TRUE(pass); - - // bf16 - pass = ck::profiler::profile_grouped_conv_fwd_impl<2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::tensor_layout::convolution::GNHWK, - ck::bhalf_t, - ck::bhalf_t, - ck::bhalf_t>(true, // do_verification - 1, // init_method - false, // do_log - false, // time_kernel - param); - - EXPECT_TRUE(pass); - - // int8 - pass = ck::profiler::profile_grouped_conv_fwd_impl<2, - ck::tensor_layout::convolution::GNHWC, - ck::tensor_layout::convolution::GKYXC, - ck::tensor_layout::convolution::GNHWK, - int8_t, - int8_t, - int8_t>(true, // do_verification - 1, // init_method - false, // do_log - false, // time_kernel - param); - - EXPECT_TRUE(pass); - } -} - -// 3d GNDHWC/GKZYXC/GNDHWK -TEST_F(TestGroupedConvNdFwd, GroupedConv3dFwdGNDHWC) -{ - conv_params.clear(); - conv_params.push_back( - {3, 2, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); - conv_params.push_back( - {3, 2, 128, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - conv_params.push_back( - {3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); - conv_params.push_back( - {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - this->conv_params.push_back( - {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - conv_params.push_back( - {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); - - for(auto& param : conv_params) - { - bool pass; - - // fp32 - pass = ck::profiler::profile_grouped_conv_fwd_impl<3, - ck::tensor_layout::convolution::GNDHWC, - ck::tensor_layout::convolution::GKZYXC, - ck::tensor_layout::convolution::GNDHWK, - float, - float, - float>(true, // do_verification - 1, // init_method - false, // do_log - false, // time_kernel - param); - - EXPECT_TRUE(pass); - - // fp16 - pass = ck::profiler::profile_grouped_conv_fwd_impl<3, - ck::tensor_layout::convolution::GNDHWC, - ck::tensor_layout::convolution::GKZYXC, - ck::tensor_layout::convolution::GNDHWK, - ck::half_t, - ck::half_t, - ck::half_t>(true, // do_verification - 1, // init_method - false, // do_log - false, // time_kernel - param); - - EXPECT_TRUE(pass); - - // bf16 - pass = ck::profiler::profile_grouped_conv_fwd_impl<3, - ck::tensor_layout::convolution::GNDHWC, - ck::tensor_layout::convolution::GKZYXC, - ck::tensor_layout::convolution::GNDHWK, - ck::bhalf_t, - ck::bhalf_t, - ck::bhalf_t>(true, // do_verification - 1, // init_method - false, // do_log - false, // time_kernel - param); - - EXPECT_TRUE(pass); - - // int8 - pass = ck::profiler::profile_grouped_conv_fwd_impl<3, - ck::tensor_layout::convolution::GNDHWC, - ck::tensor_layout::convolution::GKZYXC, - ck::tensor_layout::convolution::GNDHWK, - int8_t, - int8_t, - int8_t>(true, // do_verification - 1, // init_method - false, // do_log - false, // time_kernel - param); - - EXPECT_TRUE(pass); - } -} - -// 2d NHWGC/KYXGC/NHWGK -TEST_F(TestGroupedConvNdFwd, GroupedConv2dFwdNHWGC) -{ - conv_params.clear(); - conv_params.push_back({2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); - conv_params.push_back({2, 2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - conv_params.push_back({2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); - conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); - - for(auto& param : conv_params) - { - bool pass; - - // fp16 - pass = ck::profiler::profile_grouped_conv_fwd_impl<2, - ck::tensor_layout::convolution::NHWGC, - ck::tensor_layout::convolution::GKYXC, - ck::tensor_layout::convolution::NHWGK, - ck::half_t, - ck::half_t, - ck::half_t>(true, // do_verification - 1, // init_method - false, // do_log - false, // time_kernel - param); - - EXPECT_TRUE(pass); - } -} diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp new file mode 100644 index 000000000..dde8313f9 --- /dev/null +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp @@ -0,0 +1,133 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "profiler/profile_grouped_conv_fwd_impl.hpp" + +template +class TestGroupedConvndFwd : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using InLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using OutLayout = std::tuple_element_t<3, Tuple>; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_grouped_conv_fwd_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes1d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +using KernelTypes2d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; + +using KernelTypes3d = ::testing::Types, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple>; + +template +class TestGroupedConvndFwd1d : public TestGroupedConvndFwd +{ +}; + +template +class TestGroupedConvndFwd2d : public TestGroupedConvndFwd +{ +}; + +template +class TestGroupedConvndFwd3d : public TestGroupedConvndFwd +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwd1d, KernelTypes1d); +TYPED_TEST_SUITE(TestGroupedConvndFwd2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndFwd3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwd1d, Test1D) +{ + this->conv_params.clear(); + this->conv_params.push_back({1, 2, 32, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); + this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 2, 32, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); + this->conv_params.push_back({1, 1, 1, 1, 32, {3}, {32}, {1}, {1}, {1}, {1}}); + this->conv_params.push_back({1, 1, 1, 64, 3, {3}, {32}, {1}, {1}, {1}, {1}}); + this->template Run<1>(); +} + +TYPED_TEST(TestGroupedConvndFwd2d, Test2D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {2, 2, 32, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back( + {2, 2, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back( + {2, 2, 32, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}); + this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndFwd3d, Test3D) +{ + this->conv_params.clear(); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}); + this->conv_params.push_back( + {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->conv_params.push_back( + {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}}); + this->template Run<3>(); +} -- GitLab From 2ea75bd6d7f0a061f78e4bb007a840a394a74ba9 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 26 Sep 2023 08:39:11 -0700 Subject: [PATCH 028/112] Resolve some data type issues and cmake policy. (#940) * split the types in gemm_bilinear instances, add condition to cmake policy * fix syntax * split the data types in batchnorm examples * fix the batchnorm_bwd test * fix types in the batchnorm_bwd test --- CMakeLists.txt | 5 ++++- .../gpu/gemm_bilinear.hpp | 20 ++++++++++-------- test/batchnorm/batchnorm_bwd_rank_4.cpp | 21 +++++++++++++++---- test/batchnorm/batchnorm_fwd_rank_4.cpp | 21 +++++++++++++++---- test/batchnorm/batchnorm_infer_rank_4.cpp | 21 +++++++++++++++---- 5 files changed, 66 insertions(+), 22 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 95ef2df7e..e5c82b970 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,8 @@ cmake_minimum_required(VERSION 3.14) -cmake_policy(SET CMP0140 NEW) +if(POLICY CMP0140) + # policies CMP0140 not known to CMake until 3.25 + cmake_policy(SET CMP0140 NEW) +endif() # This has to be initialized before the project() command appears # Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp index 387499e58..1a518a530 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp @@ -11,12 +11,12 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#ifdef CK_ENABLE_FP16 + namespace ck { namespace tensor_operation { namespace device { namespace instance { - +#ifdef CK_ENABLE_FP16 void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_INT8 void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances( std::vector>>& instances); - +#endif // GEMM + Bilinear template > op_ptrs; - +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -187,8 +188,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -211,7 +214,7 @@ struct DeviceOperationInstanceFactory, - std::tuple, - std::tuple, - std::tuple>; +using KernelTypes = ::testing::Types< +#ifdef CK_ENABLE_FP16 + std::tuple +#endif +#ifdef CK_ENABLE_FP32 + , + std::tuple +#endif +#ifdef CK_ENABLE_BF16 + , + std::tuple +#endif +#ifdef CK_ENABLE_FP64 + , + std::tuple +#endif + >; TYPED_TEST_SUITE(TestBatchNormBwdRank4, KernelTypes); diff --git a/test/batchnorm/batchnorm_fwd_rank_4.cpp b/test/batchnorm/batchnorm_fwd_rank_4.cpp index 9b6fbd0f6..6bf635f0c 100644 --- a/test/batchnorm/batchnorm_fwd_rank_4.cpp +++ b/test/batchnorm/batchnorm_fwd_rank_4.cpp @@ -87,10 +87,23 @@ class TestBatchNormFwdRank4 : public ::testing::Test } }; -using KernelTypes = ::testing::Types, - std::tuple, - std::tuple, - std::tuple>; +using KernelTypes = ::testing::Types< +#ifdef CK_ENABLE_FP16 + std::tuple +#endif +#ifdef CK_ENABLE_FP32 + , + std::tuple +#endif +#ifdef CK_ENABLE_BF16 + , + std::tuple +#endif +#ifdef CK_ENABLE_FP64 + , + std::tuple +#endif + >; TYPED_TEST_SUITE(TestBatchNormFwdRank4, KernelTypes); diff --git a/test/batchnorm/batchnorm_infer_rank_4.cpp b/test/batchnorm/batchnorm_infer_rank_4.cpp index ecb4043b3..0165192ac 100644 --- a/test/batchnorm/batchnorm_infer_rank_4.cpp +++ b/test/batchnorm/batchnorm_infer_rank_4.cpp @@ -67,10 +67,23 @@ class TestBatchNormInferRank4 : public ::testing::Test } }; -using KernelTypes = ::testing::Types, - std::tuple, - std::tuple, - std::tuple>; +using KernelTypes = ::testing::Types< +#ifdef CK_ENABLE_FP16 + std::tuple +#endif +#ifdef CK_ENABLE_FP32 + , + std::tuple +#endif +#ifdef CK_ENABLE_BF16 + , + std::tuple +#endif +#ifdef CK_ENABLE_FP64 + , + std::tuple +#endif + >; TYPED_TEST_SUITE(TestBatchNormInferRank4, KernelTypes); -- GitLab From 0b296a2722e6769ddaadd93f305086f3928f28fb Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 26 Sep 2023 11:22:31 -0700 Subject: [PATCH 029/112] split ckProfiler gfx9 package into gfx90 and gfx94 (#946) --- CMakeLists.txt | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e5c82b970..7b8e43b9d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -111,16 +111,18 @@ else() add_definitions(-DPROFILER_ONLY) set(GPU_TARGETS "" CACHE STRING "" FORCE) if(GPU_TARGETS) - message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx9, gfx10, or gfx11") + message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, or gfx11") endif() - if(GPU_ARCH MATCHES "gfx9") - rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942") + if(GPU_ARCH MATCHES "gfx90") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx900;gfx906;gfx908;gfx90a") + elseif(GPU_ARCH MATCHES "gfx94") + rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx940;gfx941;gfx942") elseif(GPU_ARCH MATCHES "gfx10") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030") elseif(GPU_ARCH MATCHES "gfx11") rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102") else() - message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx9, gfx10, or gfx11") + message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, or gfx11") endif() set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE) endif() -- GitLab From 94bfa50256166d7fec4dd59c90e25d1a81b6531b Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Tue, 26 Sep 2023 14:59:33 -0500 Subject: [PATCH 030/112] Add fp8 gemm instances (#920) * Add fp8 gemm instances * Update instance naming --- .../tensor_operation_instance/gpu/gemm.hpp | 43 ++++++++++++ .../gpu/gemm/CMakeLists.txt | 5 ++ ..._shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp | 69 +++++++++++++++++++ ..._shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp | 69 +++++++++++++++++++ ..._shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp | 69 +++++++++++++++++++ ..._shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp | 66 ++++++++++++++++++ .../include/profiler/profile_gemm_impl.hpp | 6 ++ profiler/src/profile_gemm.cpp | 24 ++++++- 8 files changed, 350 insertions(+), 1 deletion(-) create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp index 3ceea58c1..7128a10fa 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -312,6 +312,23 @@ void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( DeviceGemm>>& instances); #endif +#ifdef CK_ENABLE_FP8 +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances( + std::vector>>& instances); +#endif template && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(op_ptrs); + } + } #endif return op_ptrs; } diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 22cf1a462..89820775c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -89,6 +89,11 @@ list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_ins device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp) +list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp + device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp) + add_instance_library(device_gemm_instance ${GEMM_INSTANCES}) set(ENABLE_PIPELINE_V2_OPT OFF) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp new file mode 100644 index 000000000..9739046d3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#ifdef CK_ENABLE_FP8 +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F8 = f8_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 4, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 4, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp new file mode 100644 index 000000000..810c1b87c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#ifdef CK_ENABLE_FP8 +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F8 = f8_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 4, 16, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 4, 16, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 4, 16, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 16, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp new file mode 100644 index 000000000..e1aac4513 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_kn_mn_instance.cpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#ifdef CK_ENABLE_FP8 +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F8 = f8_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp new file mode 100644 index 000000000..5c557312a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_fp8_fp8_fp8_mk_nk_mn_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#ifdef CK_ENABLE_FP8 +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; +using F8 = f8_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances = + std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 64, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 64, 64, 16, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 64, 16, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 128, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/profiler/include/profiler/profile_gemm_impl.hpp b/profiler/include/profiler/profile_gemm_impl.hpp index eaab5dbcc..ccebb020c 100644 --- a/profiler/include/profiler/profile_gemm_impl.hpp +++ b/profiler/include/profiler/profile_gemm_impl.hpp @@ -223,6 +223,12 @@ int profile_gemm_impl(int do_verification, { std::cout << "Best Perf for datatype = int8"; } +#if defined CK_ENABLE_FP8 + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = fp8"; + } +#endif if constexpr(is_same::value) { diff --git a/profiler/src/profile_gemm.cpp b/profiler/src/profile_gemm.cpp index 9ca7fc4c8..df243c96d 100644 --- a/profiler/src/profile_gemm.cpp +++ b/profiler/src/profile_gemm.cpp @@ -23,6 +23,7 @@ enum struct GemmDataType F16_F16_F16, // 1 BF16_BF16_BF16, // 2 INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 }; #define OP_NAME "gemm" @@ -31,7 +32,7 @@ enum struct GemmDataType static void print_helper_msg() { std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" - << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n" + << "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: fp8)\n" << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n" << " 1: A[m, k] * B[n, k] = C[m, n];\n" << " 2: A[k, m] * B[k, n] = C[m, n];\n" @@ -76,6 +77,9 @@ int profile_gemm(int argc, char* argv[]) using INT8 = int8_t; using INT32 = int32_t; #endif +#ifdef CK_ENABLE_FP8 + using F8 = ck::f8_t; +#endif using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -194,6 +198,24 @@ int profile_gemm(int argc, char* argv[]) { return profile(Col{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{}); } +#endif +#ifdef CK_ENABLE_FP8 + else if(data_type == GemmDataType::F8_F8_F8 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(Row{}, Row{}, Row{}, F8{}, F8{}, F32{}, F8{}); + } + else if(data_type == GemmDataType::F8_F8_F8 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(Row{}, Col{}, Row{}, F8{}, F8{}, F32{}, F8{}); + } + else if(data_type == GemmDataType::F8_F8_F8 && layout == GemmMatrixLayout::KM_KN_MN) + { + return profile(Col{}, Row{}, Row{}, F8{}, F8{}, F32{}, F8{}); + } + else if(data_type == GemmDataType::F8_F8_F8 && layout == GemmMatrixLayout::KM_NK_MN) + { + return profile(Col{}, Col{}, Row{}, F8{}, F8{}, F32{}, F8{}); + } #endif else { -- GitLab From 48ba6e8a69947c67b0b4e5021178fe3e2f2c638e Mon Sep 17 00:00:00 2001 From: zjing14 Date: Tue, 26 Sep 2023 18:40:00 -0500 Subject: [PATCH 031/112] Fixed Gemmv2r3 kpad (#938) * added kpad support into v2r3 * add generic instances * fixed comments * fixed mnk padding * Update device_batched_gemm_xdl.hpp * fixed kpad --------- Co-authored-by: Jing Zhang --- .../gpu/grid/gridwise_gemm_pipeline_v2.hpp | 4 ++-- .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 9 +++------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp index d3d7d5af8..25e1cebdb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp @@ -9,13 +9,13 @@ namespace ck { struct GridwiseGemmPipeline_v2 { - __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + __host__ __device__ static constexpr bool IsSupported(const index_t num_loop) { // TODO: improve applicability return num_loop % 2 == 0; } - __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + __host__ __device__ static constexpr bool CalculateHasMainLoop(const index_t num_loop) { return (num_loop / 2) > 1; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index b6c146ae6..e941f9655 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -175,7 +175,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; } - __host__ static auto CalculateK0(index_t K) { return math::integer_divide_floor(K, K1Value); } + __host__ static auto CalculateK0(index_t K) { return math::integer_divide_ceil(K, K1Value); } // Argument struct Problem @@ -369,9 +369,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 "Invalid tuning param!"); // check gridwise gemm pipeline - const index_t K0 = problem.K / K1Value; - const auto num_k_loop = K0 / K0PerBlock; - + const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock); if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; @@ -1026,8 +1024,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext } // check gridwise gemm pipeline - const index_t K0 = problem.K / K1; - const auto num_k_loop = K0 / K0PerBlock; + const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock); if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { -- GitLab From 420b5a038270b9107c4663dc61de4a9e5b4dd99c Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 26 Sep 2023 17:43:09 -0700 Subject: [PATCH 032/112] Use lower case for ckprofiler package. (#948) * split ckProfiler gfx9 package into gfx90 and gfx94 * use lower case for package names --- CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 7b8e43b9d..f7a9d37bb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -446,14 +446,14 @@ if(NOT DEFINED INSTANCES_ONLY) rocm_package_setup_component(profiler LIBRARY_NAME composablekernel - PACKAGE_NAME ckProfiler + PACKAGE_NAME ckprofiler ) add_subdirectory(profiler) else() #When building PROFILER_ONLY, label the package with GPU_ARCH rocm_package_setup_component(profiler LIBRARY_NAME composablekernel - PACKAGE_NAME ckProfiler_${GPU_ARCH} + PACKAGE_NAME ckprofiler_${GPU_ARCH} ) add_subdirectory(profiler) endif() -- GitLab From 11676c7e49110b917734b164a070879ad67ea35d Mon Sep 17 00:00:00 2001 From: zjing14 Date: Tue, 26 Sep 2023 21:16:23 -0500 Subject: [PATCH 033/112] Add multiple A/B support (#906) * add gridwise_multi_abd * move element_op into RunRead * merge element_wise op with data read * add multiABD example * allow packed elementwise_op * changed example * clean * clean * add is_detected * fix * minor fix * add scaleAdd_vec4 example --------- Co-authored-by: Jing Zhang --- example/60_gemm_multiABD/CMakeLists.txt | 10 + .../gemm_multiABD_xdl_fp16.cpp | 361 ++++++ ...hread_group_tensor_slice_transfer_v7r2.hpp | 214 ++++ .../gpu/device/device_gemm_multiple_abd.hpp | 60 + .../device_gemm_multiple_abd_xdl_cshuffle.hpp | 766 ++++++++++++ ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 1033 +++++++++++++++++ .../threadwise_tensor_slice_transfer_v7r2.hpp | 420 +++++++ include/ck/utility/is_detected.hpp | 34 + include/ck/utility/tuple.hpp | 2 + 9 files changed, 2900 insertions(+) create mode 100644 example/60_gemm_multiABD/CMakeLists.txt create mode 100644 example/60_gemm_multiABD/gemm_multiABD_xdl_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp create mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp create mode 100644 include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp create mode 100644 include/ck/utility/is_detected.hpp diff --git a/example/60_gemm_multiABD/CMakeLists.txt b/example/60_gemm_multiABD/CMakeLists.txt new file mode 100644 index 000000000..9e2f70649 --- /dev/null +++ b/example/60_gemm_multiABD/CMakeLists.txt @@ -0,0 +1,10 @@ +if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) +list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list2 AND target EQUAL 0) + add_example_executable(example_gemm_multiABD_xdl_fp16 gemm_multiABD_xdl_fp16.cpp) + set(target 1) + endif() +endforeach() +endif() diff --git a/example/60_gemm_multiABD/gemm_multiABD_xdl_fp16.cpp b/example/60_gemm_multiABD/gemm_multiABD_xdl_fp16.cpp new file mode 100644 index 000000000..23d41d7cc --- /dev/null +++ b/example/60_gemm_multiABD/gemm_multiABD_xdl_fp16.cpp @@ -0,0 +1,361 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using ELayout = Row; + +struct AddScale +{ + static constexpr auto I0 = ck::Number<0>{}; + static constexpr auto I1 = ck::Number<1>{}; + static constexpr auto I2 = ck::Number<2>{}; + static constexpr auto I3 = ck::Number<3>{}; + + __host__ __device__ constexpr void + operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const + { + const auto a0_v_t = ck::vector_type{a0}; + const auto a1_v_t = ck::vector_type{a1}; + + auto r_v_t = ck::vector_type{}; + + r_v_t.AsType()(I0) = + scale * (a0_v_t.AsType()[I0] + a1_v_t.AsType()[I0]); + r_v_t.AsType()(I1) = + scale * (a0_v_t.AsType()[I1] + a1_v_t.AsType()[I1]); + r_v_t.AsType()(I2) = + scale * (a0_v_t.AsType()[I2] + a1_v_t.AsType()[I2]); + r_v_t.AsType()(I3) = + scale * (a0_v_t.AsType()[I3] + a1_v_t.AsType()[I3]); + + a = r_v_t.AsType()[I0]; + } + + __host__ __device__ constexpr void + operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const + { + a = scale * (a0 + a1); + } + + static constexpr ck::index_t vec_len = 4; + + float scale = 1.0; +}; + +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const ck::half_t& d) const + { + e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); + }; + + float alpha_; + float beta_; +}; + +using AElementOp = AddScale; +using BElementOp = PassThrough; +using CDEElementOp = AlphaBetaAdd; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle< + ck::Tuple, + ck::Tuple, + ck::Tuple, + ELayout, + ck::Tuple, + ck::Tuple, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 1, + 256, + 256, + 128, + 32, + 8, + 8, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + S<4, 64, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + 1, + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideD = 4096; + ck::index_t StrideE = 4096; + + float alpha = 1.0f; + float beta = 1.0f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 6) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + alpha = std::stof(argv[4]); + beta = std::stof(argv[5]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideD = std::stoi(argv[9]); + StrideE = std::stoi(argv[10]); + + alpha = std::stof(argv[11]); + beta = std::stof(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " + "beta\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + d_device_buf.ToDevice(d_m_n.mData.data()); + e_device_buf.ToDevice(e_m_n_device_result.mData.data()); + + auto a_element_op = AElementOp{0.2}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}, + std::array{b_device_buf.GetDeviceBuffer()}, + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + std::array{StrideA, StrideA}, + std::array{StrideB}, + std::array{StrideD}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + Tensor c_m_n({M, N}); + + Tensor a_m_k({M, K}); + + for(int m = 0; m < M; ++m) + { + for(int k = 0; k < K; ++k) + { + a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k)); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp new file mode 100644 index 000000000..1a9bb3213 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp @@ -0,0 +1,214 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does following things to avoid scratch memory issue +// 1. Pass tensor descritpors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + index_t SrcScalarPerVector, + index_t DstScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags> +struct ThreadGroupTensorSliceTransfer_v7r2 +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs); + } + } + + template + using is_tuple = decltype(std::declval().IsTuple()); + + template + __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + if constexpr(is_detected::value) + threadwise_transfer_.RunWrite(dst_descs, dst_bufs); + else + threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs)); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) + { + static_for<0, SrcDescs::Size(), 1>{}( + [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); }); + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step) + { + static_for<0, DstDescs::Size(), 1>{}( + [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); }); + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp new file mode 100644 index 000000000..cbb9fadc6 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A0[M, K], B0[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleABD : public BaseOperator +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + ck::index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 000000000..827a341a5 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,766 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = as_grid_desc_ak0_m_ak1; + ignore = bs_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD +{ + using DeviceOp = DeviceGemmMultipleABD_Xdl_CShuffle; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + +#if 0 + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideAs) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideAs, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideAs)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideBs) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideBs)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideBs, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } +#endif + using ComputeDataType = EDataType; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< + AsDataType, + BsDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + // desc for problem definition + using AsGridDesc_M_K = + remove_cvref_t( + {}, {}, {}))>; + using BsGridDesc_N_K = + remove_cvref_t( + {}, {}, {}))>; + using DsGridDesc_M_N = + remove_cvref_t( + {}, {}, {}))>; + using EGridDesc_M_N = + decltype(GridwiseGemm::template MakeEGridDescriptor_M_N(1, 1, 1)); + + // desc for blockwise copy + using AsGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BsGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::array p_as_grid, + std::array p_bs_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + as_grid_desc_m_k_{}, + bs_grid_desc_n_k_{}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{GridwiseGemm::template MakeEGridDescriptor_M_N( + MRaw, NRaw, StrideE)}, + as_grid_desc_ak0_m_ak1_{}, + bs_grid_desc_bk0_n_bk1_{}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw} + { + // populate pointer, desc for As + static_for<0, NumATensor, 1>{}([&](auto i) { + using ALayout = remove_cvref_t>; + using ADataType = remove_cvref_t>; + + // A pointer + p_as_grid_(i) = static_cast(p_as_grid[i]); + + // A desc + as_grid_desc_m_k_(i) = + GridwiseGemm::template MakeAGridDescriptor_M_K( + MRaw, KRaw, StrideAs[i]); + }); + + // populate pointer, desc for Bs + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BLayout = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + // B pointer + p_bs_grid_(i) = static_cast(p_bs_grid[i]); + + // B desc + bs_grid_desc_n_k_(i) = + GridwiseGemm::template MakeBGridDescriptor_N_K( + KRaw, NRaw, StrideBs[i]); + }); + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + GridwiseGemm::template MakeEGridDescriptor_M_N( + MRaw, NRaw, StrideDs[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_, + bs_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + as_grid_desc_ak0_m_ak1_ = + GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_); + + bs_grid_desc_bk0_n_bk1_ = + GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + } + + void Print() const + { + // std::cout << "A[M, K]: " << as_grid_desc_m_k_ << std::endl; + // std::cout << "B[N, K]: " << bs_grid_desc_n_k_ << std::endl; + // static_for<0, NumDTensor, 1>{}( + //[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + // std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AsGridDesc_M_K as_grid_desc_m_k_; + BsGridDesc_N_K bs_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_; + BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_abd_xdl_cshuffle< + GridwiseGemm, + typename GridwiseGemm::AsGridPointer, + typename GridwiseGemm::BsGridPointer, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AsGridDesc_AK0_M_AK1, + DeviceOp::BsGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_as_grid_, + arg.p_bs_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.as_grid_desc_ak0_m_ak1_, + arg.bs_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // check vector load/store + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + bool all_valid = true; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ALayout = remove_cvref_t>; + // check vector load of A + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + all_valid = false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + all_valid = false; + } + } + else + { + all_valid = false; + } + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BLayout = remove_cvref_t>; + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + all_valid = false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + all_valid = false; + } + } + else + { + all_valid = false; + } + }); + + // check vector load of Ds + // only support RowMajor for now + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // check vector store of E + // only support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideAs, + StrideBs, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideAs, + StrideBs, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmMultipleABD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp new file mode 100644 index 000000000..d98cbea16 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,1033 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +namespace ck { + +// GEMM: +// input : A0[M, K], A1[M, K] +// input : B0[N, K], B1[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleABD_xdl_cshuffle +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + +#if CK_WORKAROUND_DENORM_FIX + using ComputeDataType = + conditional_t, ck::bhalf_t, ComputeDataType_>; +#else + using ComputeDataType = ComputeDataType_; +#endif + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + static constexpr auto MakeAsGridPointer() + { + return generate_tuple( + [&](auto i) { + using ADataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeBsGridPointer() + { + return generate_tuple( + [&](auto i) { + using BDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(ComputeDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + // A desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k) + { + return generate_tuple( + [&](auto i) { return MakeAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); }, + Number{}); + } + + // B desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k) + { + return generate_tuple( + [&](auto i) { return MakeBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); }, + Number{}); + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AsGridDesc_M_K& as_grid_desc_m_k, + const BsGridDesc_N_K& bs_grid_desc_n_k, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "KPerBlock must be divisible by AK1Value and BK1Value!"); + + const auto M = as_grid_desc_m_k[I0].GetLength(I0); + const auto N = bs_grid_desc_n_k[I0].GetLength(I0); + const auto AK = as_grid_desc_m_k[I0].GetLength(I1); + const auto BK = bs_grid_desc_n_k[I0].GetLength(I1); + + // check consistency of desc + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK)) + { + return false; + } + + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + bool valid = true; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + valid = + valid && (as_grid_desc_m_k[i].GetElementSpaceSize() * sizeof(ADataType) <= TwoGB); + valid = valid && (M == as_grid_desc_m_k[i].GetLength(I0) && + AK == as_grid_desc_m_k[i].GetLength(I1)); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + valid = + valid && (bs_grid_desc_n_k[i].GetElementSpaceSize() * sizeof(BDataType) <= TwoGB); + valid = valid && (N == bs_grid_desc_n_k[i].GetLength(I0) && + BK == bs_grid_desc_n_k[i].GetLength(I1)); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0)) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = AK / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // check block-to-E-tile + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + + if(!(e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using AsGridPointer = decltype(MakeAsGridPointer()); + using BsGridPointer = decltype(MakeBsGridPointer()); + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + template + __host__ __device__ static auto + MakeAsGridDescriptor_M_K(const std::array& MRaws, + const std::array& KRaws, + const std::array& AsStride) + { + return generate_tuple( + [&](auto i) { + using ALayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], KRaws[i], AsStride[i]); + }, + Number{}); + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + __host__ __device__ static auto + MakeBsGridDescriptor_N_K(const std::array& KRaws, + const std::array& NRaws, + const std::array& BsStride) + { + return generate_tuple( + [&](auto i) { + using BLayout = remove_cvref_t>; + + return MakeBGridDescriptor_N_K(KRaws[i], NRaws[i], BsStride[i]); + }, + Number{}); + } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(AsGridPointer p_as_grid, + BsGridPointer p_bs_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto as_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto bs_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + const auto idx_as_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, + Number{}); + + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + AsDataType, + Tuple, + decltype(as_grid_desc_ak0_m_ak1), + decltype(tie(a_block_desc_ak0_m_ak1)), + AElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + uniform_sequence_gen_t, + Sequence>{as_grid_desc_ak0_m_ak1, + idx_as_block_begin, + tie(a_block_desc_ak0_m_ak1), + make_tuple(make_multi_index(0, 0, 0)), + a_element_op}; + + const auto idx_bs_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, + Number{}); + + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + BsDataType, + Tuple, + decltype(bs_grid_desc_bk0_n_bk1), + decltype(tie(b_block_desc_bk0_n_bk1)), + BElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + uniform_sequence_gen_t, + Sequence>{bs_grid_desc_bk0_n_bk1, + idx_bs_block_begin, + tie(b_block_desc_bk0_n_bk1), + make_tuple(make_multi_index(0, 0, 0)), + b_element_op}; + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr index_t KPack = + math::max(math::lcm(AK1, BK1), + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) / + KPerBlock); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + gridwise_gemm_pipeline.template Run(as_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + as_grid_buf, + a_block_buf, + a_block_slice_copy_step, + bs_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + bs_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run(AsGridPointer p_as_grid, + BsGridPointer p_bs_grid, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const std::array StrideAs, + const std::array StrideBs, + const std::array StrideDs, + const index_t StrideE, + const Block2ETileMap& block_2_etile_map) + { + using AsGridDesc_M_K = + remove_cvref_t({}, {}, {}))>; + using BsGridDesc_N_K = + remove_cvref_t({}, {}, {}))>; + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + AsGridDesc_M_K as_grid_desc_m_k; + BsGridDesc_N_K bs_grid_desc_n_k; + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumATensor, 1>{}([&](auto j) { + using ALayout = remove_cvref_t>; + + as_grid_desc_m_k(j) = MakeAGridDescriptor_M_K(M, K, StrideAs[j]); + }); + + static_for<0, NumBTensor, 1>{}([&](auto j) { + using BLayout = remove_cvref_t>; + + bs_grid_desc_n_k(j) = MakeBGridDescriptor_N_K(N, K, StrideBs[j]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k); + + const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp new file mode 100644 index 000000000..003a2b96d --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp @@ -0,0 +1,420 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + index_t SrcScalarPerVector, + index_t DstScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags> // Sequence +struct ThreadwiseTensorSliceTransfer_v7r2 +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SrcSpaceFillingCurve = SpaceFillingCurve>; + + static constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using DstSpaceFillingCurve = SpaceFillingCurve>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! cannot evenly divide"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + }); + } + + template + __device__ static auto generate_vectors() + { + auto data_types = DataTypes{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + } + + template + using has_vec_len = decltype(std::declval().vec_len); + + // SrcDescs: Tuple + // SrcBuffers: Tuple + template = false> + __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs) + { + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(); + auto dst_vectors = generate_vectors(); + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), + is_src_valid); + }); + + if constexpr(is_detected::value) + { + constexpr auto elem_op_vec_len = decltype(element_op_)::vec_len; + + static_assert(is_same, index_t>::value, + "vec_len in element_op_ type is not index_t"); + + static_assert(elem_op_vec_len == 1 || elem_op_vec_len == 2 || + elem_op_vec_len == 4 || elem_op_vec_len == 8, + "vec_len in element_op_ must be 1, 2, 4, 8"); + + static_assert(SrcScalarPerVector % elem_op_vec_len == 0, + "vec_len in element_op_ cannot be divided by SrcScalarPerVector!"); + + // apply pointwise function + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = + typename vector_type::type; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + using elem_op_vec_t = + typename vector_type::type; + + return dst_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + } + else + { + // apply pointwise function + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + return dst_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + } + + dst_vectors_tuple_(iAccess) = dst_vectors; + + // move coordinate + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + } + + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs) + { + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + auto dst_vectors = dst_vectors_tuple_[iAccess]; + + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + dst_bufs(i).template Update( + dst_coords_[i].GetOffset(), + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + if constexpr(num_access == 0) + { + return typename SrcSpaceFillingCurve::Index{}; + } + else + { + return SrcSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + if constexpr(num_access == 0) + { + return typename DstSpaceFillingCurve::Index{}; + } + else + { + return DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + using SrcVectorsType = decltype(generate_vectors()); + using DstVectorsType = decltype(generate_vectors()); + + static constexpr auto num_access = SrcSpaceFillingCurve::GetNumOfAccess(); + + StaticallyIndexedArray dst_vectors_tuple_; + + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/include/ck/utility/is_detected.hpp b/include/ck/utility/is_detected.hpp new file mode 100644 index 000000000..31a4bf24d --- /dev/null +++ b/include/ck/utility/is_detected.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +namespace detail { +template class Op, class... Args> +struct detector +{ + using value_t = std::false_type; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> +{ + using value_t = std::true_type; + using type = Op; +}; +} // namespace detail + +struct nonesuch +{ + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; + void operator=(nonesuch const&) = delete; +}; + +template