From 8e97e85ac6cdb71903d3ac46a7e82f8350eb0ce5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Mar 2024 08:21:14 -0700 Subject: [PATCH 01/63] Bump rocm-docs-core from 0.35.1 to 0.36.0 in /docs/sphinx (#1194) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.35.1 to 0.36.0. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.35.1...v0.36.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 93c15a216..b3c826773 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.35.1 +rocm-docs-core==0.36.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 8faeac85d..ba1d7da44 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -113,7 +113,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.35.1 +rocm-docs-core==0.36.0 # via -r requirements.in six==1.16.0 # via -- GitLab From 12441af014d8c865a0254efea3f82a07bdf58b4f Mon Sep 17 00:00:00 2001 From: randyh62 <42045079+randyh62@users.noreply.github.com> Date: Tue, 12 Mar 2024 18:25:48 -0700 Subject: [PATCH 02/63] Doc reorg2 (#1189) * doc_reorg2 updated TOC * doc_reorg2 updates * fix conflicts, add grid --- docs/{ => conceptual}/what-is-ck.rst | 4 +-- docs/index.rst | 25 +++++++------- docs/{ => install}/dockerhub.rst | 0 docs/license.md | 2 -- docs/license.rst | 11 +++++++ docs/{ => reference}/API_Reference_Guide.rst | 0 .../Supported_Primitives_Guide.rst | 0 docs/{ => reference}/wrapper.rst | 0 docs/sphinx/_toc.yml.in | 33 ++++++++++++++----- docs/{ => tutorial}/tutorial_hello_world.rst | 0 10 files changed, 49 insertions(+), 26 deletions(-) rename docs/{ => conceptual}/what-is-ck.rst (94%) rename docs/{ => install}/dockerhub.rst (100%) delete mode 100644 docs/license.md create mode 100644 docs/license.rst rename docs/{ => reference}/API_Reference_Guide.rst (100%) rename docs/{ => reference}/Supported_Primitives_Guide.rst (100%) rename docs/{ => reference}/wrapper.rst (100%) rename docs/{ => tutorial}/tutorial_hello_world.rst (100%) diff --git a/docs/what-is-ck.rst b/docs/conceptual/what-is-ck.rst similarity index 94% rename from docs/what-is-ck.rst rename to docs/conceptual/what-is-ck.rst index f0b51c48f..36785fc6c 100644 --- a/docs/what-is-ck.rst +++ b/docs/conceptual/what-is-ck.rst @@ -20,7 +20,7 @@ CK utilizes two concepts to achieve performance portability and code maintainabi * Algorithm complexity reduction for complex ML operators using an innovative technique called "Tensor Coordinate Transformation". -.. image:: data/ck_component.png +.. image:: ../data/ck_component.png :alt: CK Components @@ -36,6 +36,6 @@ The CK library is structured into 4 layers: It also includes a simple wrapper component used to perform tensor transform operations more easily and with fewer lines of code. -.. image:: data/ck_layer.png +.. 
image:: ../data/ck_layer.png :alt: CK Layers \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 8ae4ce3a2..55c80b8ed 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,27 +12,26 @@ The Composable Kernel (CK) library provides a programming model for writing perf The CK documentation is structured as follows: -.. card:: Conceptual +.. grid:: 2 + :gutter: 3 - * :ref:`what-is-ck` + .. grid-item-card:: Installation -.. card:: Installation + * :ref:`docker-hub` - * :ref:`docker-hub` + .. grid-item-card:: Conceptual -.. card:: Tutorial + * :ref:`what-is-ck` - * :ref:`hello-world` + .. grid-item-card:: API reference -.. card:: API reference + * :ref:`supported-primitives` + * :ref:`api-reference` + * :ref:`wrapper` - * :ref:`supported-primitives` - * :ref:`api-reference` - * :ref:`wrapper` + .. grid-item-card:: Tutorial -.. card:: Contributing to CK - - * :ref:`contributing-to` + * :ref:`hello-world` To contribute to the documentation refer to `Contributing to ROCm `_. diff --git a/docs/dockerhub.rst b/docs/install/dockerhub.rst similarity index 100% rename from docs/dockerhub.rst rename to docs/install/dockerhub.rst diff --git a/docs/license.md b/docs/license.md deleted file mode 100644 index 43e471da0..000000000 --- a/docs/license.md +++ /dev/null @@ -1,2 +0,0 @@ -```{include} ../LICENSE.md -``` diff --git a/docs/license.rst b/docs/license.rst new file mode 100644 index 000000000..1e5389ccc --- /dev/null +++ b/docs/license.rst @@ -0,0 +1,11 @@ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _license: + +******************************************************************** +License +******************************************************************** + +.. include:: ../LICENSE \ No newline at end of file diff --git a/docs/API_Reference_Guide.rst b/docs/reference/API_Reference_Guide.rst similarity index 100% rename from docs/API_Reference_Guide.rst rename to docs/reference/API_Reference_Guide.rst diff --git a/docs/Supported_Primitives_Guide.rst b/docs/reference/Supported_Primitives_Guide.rst similarity index 100% rename from docs/Supported_Primitives_Guide.rst rename to docs/reference/Supported_Primitives_Guide.rst diff --git a/docs/wrapper.rst b/docs/reference/wrapper.rst similarity index 100% rename from docs/wrapper.rst rename to docs/reference/wrapper.rst diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index 578067462..533b81cd3 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -2,20 +2,35 @@ defaults: numbered: False root: index subtrees: -- entries: - - file: what-is-ck.rst + +- caption: Conceptual + entries: + - file: conceptual/what-is-ck.rst title: What is Composable Kernel? 
- - file: dockerhub.rst + +- caption: Install + entries: + - file: install/dockerhub.rst title: Docker Hub - - file: tutorial_hello_world.rst - title: Hello World Tutorial - - file: Supported_Primitives_Guide.rst + +- caption: CK API Reference + entries: + - file: reference/Supported_Primitives_Guide.rst title: Supported Primitives - - file: API_Reference_Guide.rst + - file: reference/API_Reference_Guide.rst title: API Reference - - file: wrapper.rst + - file: reference/wrapper.rst title: Wrapper + +- caption: Tutorial + entries: + - file: tutorial/tutorial_hello_world.rst + title: Hello World Tutorial + +- caption: About + entries: - file: Contributors_Guide.rst title: Contributing to CK - - file: license.md + - file: license.rst title: License + \ No newline at end of file diff --git a/docs/tutorial_hello_world.rst b/docs/tutorial/tutorial_hello_world.rst similarity index 100% rename from docs/tutorial_hello_world.rst rename to docs/tutorial/tutorial_hello_world.rst -- GitLab From 285251768e8026689411d330def1aa6a2329b544 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 13 Mar 2024 23:09:08 +0100 Subject: [PATCH 03/63] Add conv fwd/bwd data scale instances, extend bilinear instances (#1178) * Add conv fwd/bwd data scale instances * Fix cmake client example file --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> --- .../24_grouped_conv_activation/CMakeLists.txt | 8 + .../grouped_conv_bwd_data_scale_fp16.cpp | 216 +++++++++++++++++ .../grouped_conv_fwd_scale_fp16.cpp | 220 ++++++++++++++++++ .../element/unary_element_wise_operation.hpp | 6 + ...ed_conv_bwd_data_xdl_bilinear_instance.hpp | 137 ++++++----- ...ouped_conv_bwd_data_xdl_scale_instance.hpp | 149 ++++++++++++ ...grouped_conv_fwd_xdl_bilinear_instance.hpp | 120 +++++++--- ...ce_grouped_conv_fwd_xdl_scale_instance.hpp | 179 ++++++++++++++ ...rouped_convolution_backward_data_scale.hpp | 150 ++++++++++++ .../gpu/grouped_convolution_forward_scale.hpp | 175 ++++++++++++++ .../CMakeLists.txt | 6 + ...ale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 50 ++++ ...cale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 50 ++++ ...cale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 50 ++++ .../grouped_conv3d_fwd_scale/CMakeLists.txt | 7 + ...ale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 55 +++++ ...cale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 54 +++++ ...cale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 54 +++++ ...ale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp | 54 +++++ 19 files changed, 1644 insertions(+), 96 deletions(-) create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt create mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index b4895db89..074dcd9b9 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -38,3 +38,11 @@ target_link_libraries(client_grouped_convnd_fwd_bilinear_residual_fp16 PRIVATE c add_executable(client_grouped_convnd_bwd_data_bilinear_residual_fp16 grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp) target_link_libraries(client_grouped_convnd_bwd_data_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations) +# Fwd scale +add_executable(client_grouped_convnd_fwd_scale_fp16 + grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scale_fp16 PRIVATE composable_kernel::device_conv_operations) +# Bwd data scale +add_executable(client_grouped_convnd_bwd_data_scale_fp16 + grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp) +target_link_libraries(client_grouped_convnd_bwd_data_scale_fp16 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp new file mode 100644 index 000000000..e53ecc6c9 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_bwd_data_scale() +{ + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + PassThrough, + PassThrough, + Scale>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr 
= op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Scale{2.f}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X + + 3 * G * N * Di * Hi * Wi * C; + std::size_t num_bytes = 2 * sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Scale{2.f}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_bwd_data_scale(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp new file mode 100644 index 000000000..11e69f5bb --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_fwd_scale() +{ + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space. + // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW. + // Hence, we need to adjust the order of strides. 
+ std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + // Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW) + std::array bias_lengths{G, 1, K, 1, 1, 1}; + std::array bias_strides{K, 0, 1, 0, 0, 0}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Di * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + Scale>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Scale{2.f}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * G * N * K * C * Ho * Wo * Y * X + 3 * N * Ho * Wo * G * K; + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(WeiDataType) * G * K * Y * X * C + + sizeof(OutDataType) * 2 * N * Ho * Wo * G * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + 
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Scale{2.f}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_fwd_scale(); } diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index c6d933893..9c64ad4df 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -310,6 +310,12 @@ struct Scale y = scale_ * x; }; + template <> + __host__ __device__ void operator()(int8_t& y, const int8_t& x) const + { + y = ck::type_convert(scale_ * ck::type_convert(x)); + }; + float scale_; }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp index 93a1ef209..216b4e2fe 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp @@ -18,8 +18,6 @@ namespace instance { using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; -using BF8 = ck::bf8_t; -using F8 = ck::f8_t; template using S = ck::Sequence; @@ -35,27 +33,42 @@ static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; // f16_f16_f32_f16 + template -using device_grouped_conv_bwd_data_xdl_bilinear_f16_instances = std::tuple< - // clang-format off - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| 
PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +using device_grouped_conv_bwd_data_xdl_bilinear_f16_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, F16, F16, F32, F16, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, F16, F16, F32, F16, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, F16, F16, F32, F16, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, F16, F16, F32, F16, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8> - // clang-format on - >; + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, 
true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + 
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; // bf16_bf16_f32_bf16 template using device_grouped_conv_bwd_data_xdl_bilinear_bf16_instances = std::tuple< // clang-format off - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // 
##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, BF16, BF16, F32, BF16, Tuple, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, BF16, BF16, F32, BF16, Tuple, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, BF16, BF16, F32, BF16, Tuple, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, BF16, BF16, F32, BF16, Tuple, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, 
Bilinear, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -87,44 +113,35 @@ template -using device_grouped_conv_bwd_data_xdl_bilinear_f32_instances = std::tuple< - // clang-format off - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +using device_grouped_conv_bwd_data_xdl_bilinear_f32_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, 
BLayout, ck::Tuple, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, - - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4> - // clang-format on - >; - -// f16_f16_f16_comp_f8 -template -using device_grouped_conv_bwd_data_xdl_bilinear_input_fp16_comp_bf8f8_instances = std::tuple< - // clang-format off - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // generic instance - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F16, F16, F32, F32, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, BF8, F8>, - // instances for small conv.K and conv.C - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F16, F16, F32, F32, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, - 
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F16, F16, F32, F32, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1, LoopScheduler::Default, BF8, F8>, - - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F16, F16, F32, F32, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8> - // clang-format on - >; + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, 
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp new file mode 100644 index 000000000..d278b9a48 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp @@ -0,0 +1,149 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// f16_f16_f32_f16 + +template +using device_grouped_conv_bwd_data_xdl_scale_f16_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 
8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 
2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +// bf16_bf16_f32_bf16 +template +using device_grouped_conv_bwd_data_xdl_scale_bf16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, 
PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +// f32_f32_f32_f32 +template +using device_grouped_conv_bwd_data_xdl_scale_f32_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| 
DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, 
S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 
4, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp index 3c689990a..1c3bfef8c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp @@ -45,17 +45,29 @@ template using device_grouped_conv_fwd_xdl_bilinear_bf16_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| 
SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, 
GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -67,17 +79,29 @@ template using device_grouped_conv_fwd_xdl_bilinear_f16_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| 
Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 
2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, 
PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -89,17 +113,29 @@ template using device_grouped_conv_fwd_xdl_bilinear_f32_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | 
PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 
128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; @@ -111,17 +147,29 @@ template using device_grouped_conv_fwd_xdl_bilinear_int8_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - 
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, 
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp new file mode 100644 index 000000000..f4dfc8f77 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_conv_fwd_xdl_scale_bf16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + 
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 
1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_scale_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 
16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 
16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_scale_f32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 
128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_scale_int8_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + 
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 
8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp new file mode 100644 index 000000000..c25c492e4 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
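// NOTE (illustrative, not part of this patch): the instance tables above lost the leading
// template arguments of every DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle entry during text
// extraction (everything from the "<" after the class name up to the first ">" was stripped).
// Going by the surviving column headers, the generic bf16 Scale entry plausibly reads in full
// as follows, assuming float accumulation, a BF16 CShuffle type, and an empty D-tensor tuple:
//
//   DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout,
//       DsLayout, ELayout, BF16, BF16, F32, BF16, Tuple<>, BF16,
//       PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding,
//       1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2,
//       S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1,
//       S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1,
//       1, 1, S<1, 16, 1, 4>, 1>,
//
// The remaining rows share the same stripped prefix and differ only in the surviving
// block/tile parameters.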
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector, + NDHWGC, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances( + std::vector, + NDHWGC, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector, + NDHWGC, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD< + NumDimSpatial, + OutLayout, + WeiLayout, + Tuple<>, + InLayout, + OutDataType, + WeiDataType, + Tuple<>, + InDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Scale, + ComputeTypeA, + ComputeTypeB>> +{ + using DeviceOp = + DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + Tuple<>, + InDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Scale, + ComputeTypeA, + ComputeTypeB>; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP32 + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + op_ptrs); + } +#endif + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp new file mode 100644 index 000000000..c4bc1da57 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
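// NOTE (illustrative sketch, not part of this patch): this header and the backward-data
// counterpart above expose the new Scale instances through the usual CK factory pattern.
// A minimal client-side loop, assuming the standard GetInstances / MakeArgumentPointer /
// IsSupportedArgument / MakeInvokerPointer API; the MakeArgumentPointer parameter list is
// elided because it is long and operation-specific, and Scale carries the runtime scaling
// factor applied to the convolution output:
//
//   using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
//       3, NDHWGC, GKZYXC, ck::Tuple<>, NDHWGK, F16, F16, ck::Tuple<>, F16,
//       PassThrough, PassThrough, Scale>;
//
//   const auto op_ptrs = ck::tensor_operation::device::instance::
//       DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
//
//   for(const auto& op_ptr : op_ptrs)
//   {
//       auto arg_ptr = op_ptr->MakeArgumentPointer(/* tensor pointers, lengths, strides,
//                                                     conv params..., PassThrough{},
//                                                     PassThrough{}, Scale{1.5f} */);
//       if(op_ptr->IsSupportedArgument(arg_ptr.get()))
//       {
//           op_ptr->MakeInvokerPointer()->Run(arg_ptr.get(), StreamConfig{});
//       }
//   }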
+ +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple<>, + F32, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instances( + std::vector, + NDHWGK, + int8_t, + int8_t, + ck::Tuple<>, + int8_t, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = + DeviceGroupedConvFwdMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v && + DLayouts::Size() == 0) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instances( + op_ptrs); + } +#endif + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt new file mode 100644 index 000000000..b7901a281 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt @@ -0,0 +1,6 @@ +set(GROUPED_CONV3D_BWD_DATA_BILINEAR + xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp) + +add_instance_library(device_grouped_conv3d_bwd_data_scale_instance ${GROUPED_CONV3D_BWD_DATA_BILINEAR}) diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 000000000..af94c0ce9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector, + NDHWGC, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_scale_bf16_instances<3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_scale_bf16_instances<3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 000000000..cc8995320 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector, + NDHWGC, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_scale_f16_instances<3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. 
Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_scale_f16_instances<3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 000000000..5ed7962bb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances( + std::vector, + NDHWGC, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_scale_f32_instances<3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_scale_f32_instances<3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt new file mode 100644 index 000000000..45d270d55 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt @@ -0,0 +1,7 @@ +set(GROUPED_CONV3D_FWD_BILINEAR + xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp) + +add_instance_library(device_grouped_conv3d_fwd_scale_instance ${GROUPED_CONV3D_FWD_BILINEAR}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 000000000..acff3e81b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 000000000..dacbfe678 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_f16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_scale_f16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_f16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 000000000..9e2c1131a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple<>, + F32, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_f32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_scale_f32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_f32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp new file mode 100644 index 000000000..f9cbf1c44 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instances( + std::vector, + NDHWGK, + int8_t, + int8_t, + ck::Tuple<>, + int8_t, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_int8_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_int8_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_int8_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck -- GitLab From e626d5202ab826ee22b369d053ab9d42ab343cff Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Fri, 15 Mar 2024 11:50:03 -0500 Subject: [PATCH 04/63] Add instances for conv_scale with fp8 in/out (#1193) * Add fp8 conv instances and client example * Format * Add example * Update cmakelists * Add profiler mode * Format * Fix copyright headers --- client_example/16_convnd_fwd/CMakeLists.txt | 3 + .../16_convnd_fwd/conv3d_fwd_fp8.cpp | 46 ++++++++++ example/09_convnd_fwd/CMakeLists.txt | 1 + example/09_convnd_fwd/convnd_fwd_common.hpp | 91 ++++++++++++++++++- example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp | 81 +++++++++++++++++ 
...ed_conv_fwd_bias_relu_add_wmma_example.inc | 7 +- .../device_grouped_conv_fwd_xdl_instance.hpp | 38 +++++++- .../gpu/grouped_convolution_forward.hpp | 23 ++++- .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 5 + ..._xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp | 53 +++++++++++ profiler/src/profile_grouped_conv_fwd.cpp | 11 ++- 11 files changed, 349 insertions(+), 10 deletions(-) create mode 100644 client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp create mode 100644 example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index 5279e3dfc..e2797415e 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -7,6 +7,9 @@ endif() if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp) target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations) + + add_executable(client_conv3d_fwd_fp8 conv3d_fwd_fp8.cpp) + target_link_libraries(client_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations) endif() if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp new file mode 100644 index 000000000..2506e29e0 --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using OutDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? 
EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index f9903bfe0..a3f63350f 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -6,6 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) + add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) set(target 1) diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index 109b8f9ee..b0fd6a382 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -27,6 +27,88 @@ void print_helper_msg() << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; } +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + template (), + get_atol()); } return true; diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp new file mode 100644 index 000000000..ef130148b --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
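// NOTE (illustrative, not part of this patch): the get_rtol/get_atol helpers added to
// convnd_fwd_common.hpp above exist because fp8 results cannot be checked with the default
// tolerances; adjacent representable f8 values can be 16 apart (the "240 and 224 are
// acceptable" comments), hence rtol = 1e-1 and atol = 16.1. A sketch of the verification
// call they feed, assuming ck::utils::check_err's (result, reference, message, rtol, atol)
// overload and hypothetical out_device / out_host tensors:
//
//   return ck::utils::check_err(out_device.mData,
//                               out_host.mData,
//                               "Error: incorrect results!",
//                               get_rtol<OutDataType>(),
//                               get_atol<OutDataType>());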
+ +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; +using OutDataType = ck::f8_t; +using ComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + ComputeDataType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc index ca8746bb9..3248c5fa4 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
template struct LayoutSetting @@ -279,8 +279,9 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[]) switch(conv_param.num_dim_spatial_) { // case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param); - case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param); - // case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param); + case 2: + return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param); + // case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param); } return false; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index 56b362eb9..e6040e0d9 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -214,6 +214,42 @@ using device_grouped_conv_fwd_xdl_f16_comp_f8_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_f8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ComputeType| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#ifdef CK_ENABLE_FP8 + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 
8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8> +#endif + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 1be5c324c..7d3071c17 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
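// This header declares the per-data-type instance populators
// (add_device_grouped_conv3d_fwd_xdl_*_instances) and the DeviceOperationInstanceFactory
// that picks one of them through if-constexpr checks on the In/Wei/Out data types and
// layouts; the additions below register the new fp8 NDHWGC/GKZYXC/NDHWGK 3-D forward
// instances next to the existing CK_ENABLE_FP32/FP16/BF16/INT8 paths.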
#pragma once @@ -727,6 +727,21 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance PassThrough, PassThrough, F8>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 @@ -1137,6 +1152,12 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances(op_ptrs); + } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 540ce3410..998c1a51a 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -30,4 +30,9 @@ if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp) endif() +if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) + list(APPEND GROUPED_CONV3D_FWD + xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp) +endif() + add_instance_library(device_grouped_conv3d_fwd_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp new file mode 100644 index 000000000..48ec4397b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index d0b424cde..7dff5bf5c 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
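// In the profiler, data-type code 4 (ConvDataType::F8_F8_F8) selects fp8 input, weight and
// output; for the 3-D NDHWGC/GKZYXC/NDHWGK layout case it forwards to
//     profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{});
// mirroring the existing int8 branch. An illustrative invocation (the trailing size
// arguments follow whatever print_helper_msg() documents) would look roughly like:
//     bin/ckProfiler grouped_conv_fwd 4 1 1 ...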
#include #include @@ -23,6 +23,7 @@ enum struct ConvDataType F16_F16_F16, // 1 BF16_BF16_BF16, // 2 INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 }; #define OP_NAME "grouped_conv_fwd" @@ -36,7 +37,8 @@ static void print_helper_msg() << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" << " 1: Input fp16, Weight fp16, Output fp16\n" << " 2: Input bf16, Weight bf16, Output bf16\n" - << " 3: Input int8, Weight int8, Output int8)\n" + << " 3: Input int8, Weight int8, Output int8\n" + << " 4: Input fp8, Weight fp8, Output fp8)\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" << "arg4: verification (0: no, 1: yes)\n" @@ -79,6 +81,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using F16 = ck::half_t; using BF16 = ck::bhalf_t; using INT8 = int8_t; + using F8 = ck::f8_t; // using GNWC = ck::tensor_layout::convolution::GNWC; @@ -250,6 +253,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{}); } + else if(data_type == ConvDataType::F8_F8_F8) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}); + } } std::cout << "this data_type & layout is not implemented" << std::endl; -- GitLab From bdcd037428ac356e5b77271b7b6669c5c2d9548a Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 18 Mar 2024 09:48:29 -0700 Subject: [PATCH 05/63] Re-enable the performance tracking in CI. (#1203) * test CK with rocm6.1 RC2 * add docker credentials for pull * update the performance db name * use environment variable for db name * add rocm-llvm-dev package to ck docker * turn off verification for daily performance runs * do not stash ckProfiler on MI300 node * add processing of mixed gemms to qa, fix parsing of splitk gemm logs * fix the splitk gemm log file name * turn the timing on for splitk gemm performance --- Dockerfile | 19 ++++++----- Jenkinsfile | 47 +++++++++++++++------------- script/process_perf_data.py | 9 ++++-- script/run_full_performance_tests.sh | 26 +++++---------- 4 files changed, 52 insertions(+), 49 deletions(-) diff --git a/Dockerfile b/Dockerfile index 38f234943..e3e791729 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,17 +16,17 @@ RUN apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg -RUN if [ "$ROCMVERSION" != "6.0.1" ]; then \ +RUN if [ "$ROCMVERSION" != "6.1" ]; then \ sh -c "wget https://repo.radeon.com/amdgpu-install/6.0/ubuntu/focal/amdgpu-install_6.0.60000-1_all.deb --no-check-certificate" && \ apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.0.60000-1_all.deb && \ wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ - elif [ "$ROCMVERSION" = "6.0.1" ] && [ "$compiler_version" = "rc1" ]; then \ - sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.0-20.04-1_all.deb 
--no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.0-20.04-1_all.deb && \ - sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.0.1 rel-95 > /etc/apt/sources.list.d/rocm-build.list' && \ - amdgpu-repo --amdgpu-build=1704947; \ + elif [ "$ROCMVERSION" = "6.1" ] && [ "$compiler_version" = "rc2" ]; then \ + sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.1-20.04-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.1-20.04-1_all.deb && \ + sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.1 rel-48 > /etc/apt/sources.list.d/rocm-build.list' && \ + amdgpu-repo --amdgpu-build=1736298; \ fi RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" @@ -41,6 +41,7 @@ chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION} # Install dependencies +# hipTensor requires rocm-llvm-dev for rocm versions > 6.0.1 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ build-essential \ cmake \ @@ -60,6 +61,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- python3-dev \ python3-pip \ redis \ + rocm-llvm-dev \ sshpass \ stunnel \ software-properties-common \ @@ -73,6 +75,9 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- apt-get clean && \ rm -rf /var/lib/apt/lists/* +# Update the cmake to version 3.27.5 +RUN pip install --upgrade cmake==3.27.5 + #Install latest ccache RUN git clone https://github.com/ccache/ccache.git && \ cd ccache && mkdir build && cd build && cmake .. 
&& make install @@ -82,8 +87,6 @@ RUN wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releas RUN gunzip /usr/local/bin/ninja.gz RUN chmod a+x /usr/local/bin/ninja RUN git clone https://github.com/nico/ninjatracing.git -# Update the cmake to the latest version -RUN pip install --upgrade cmake==3.27.5 #Install latest cppcheck RUN git clone https://github.com/danmar/cppcheck.git && \ diff --git a/Jenkinsfile b/Jenkinsfile index abecb7640..e60bae2b6 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -38,7 +38,7 @@ def getDockerImageName(){ img = "${params.USE_CUSTOM_DOCKER}" } else{ - if (params.ROCMVERSION != "6.0.1"){ + if (params.ROCMVERSION != "6.1"){ if (params.COMPILER_VERSION == "") { img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" } @@ -117,7 +117,9 @@ def getDockerImage(Map conf=[:]){ { echo "Pulling down image: ${image}" retimage = docker.image("${image}") - retimage.pull() + withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + retimage.pull() + } } catch(Exception ex) { @@ -406,7 +408,7 @@ def runCKProfiler(Map conf=[:]){ dir("script"){ if (params.RUN_FULL_QA){ - sh "./run_full_performance_tests.sh 1 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" + sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" archiveArtifacts "perf_gemm.log" archiveArtifacts "perf_resnet50_N256.log" archiveArtifacts "perf_resnet50_N4.log" @@ -416,9 +418,9 @@ def runCKProfiler(Map conf=[:]){ archiveArtifacts "perf_conv_bwd_data.log" archiveArtifacts "perf_gemm_bilinear.log" archiveArtifacts "perf_reduction.log" - archiveArtifacts "perf_splitK_gemm_verify.log" archiveArtifacts "perf_splitK_gemm.log" archiveArtifacts "perf_onnx_gemm.log" + archiveArtifacts "perf_mixed_gemm.log" // stash perf files to master stash name: "perf_gemm.log" stash name: "perf_resnet50_N256.log" @@ -431,6 +433,7 @@ def runCKProfiler(Map conf=[:]){ stash name: "perf_reduction.log" stash name: "perf_splitK_gemm.log" stash name: "perf_onnx_gemm.log" + stash name: "perf_mixed_gemm.log" //we will process results on the master node } else{ @@ -493,9 +496,6 @@ def Build_CK(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - def navi_node = 0 - def mi300_node = 0 - gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) @@ -508,14 +508,6 @@ def Build_CK(Map conf=[:]){ else{ echo "GPU is OK" } - if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') ){ - navi_node = 1 - echo "This is a Navi node" - } - if ( runShell('grep -n "gfx942" rocminfo.log') ){ - mi300_node = 1 - echo "This is MI300 node" - } } } } @@ -526,15 +518,27 @@ def Build_CK(Map conf=[:]){ withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 24, unit: 'HOURS') { + //check whether running on Navi or MI300 node + def navi_node = 0 + def mi300_node = 0 + sh 'rocminfo | tee rocminfo.log' + if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') ){ + navi_node = 1 + echo "This is a Navi node" + } + if ( runShell('grep -n "gfx942" rocminfo.log') ){ + mi300_node = 1 + echo "This is MI300 node" + } cmake_build(conf) dir("build"){ //run tests and examples sh 'make -j check' - if (navi_node == 0 ){ + if (params.RUN_PERFORMANCE_TESTS && navi_node == 0 && mi300_node == 0 ){ //we only need the ckProfiler to run 
the performance tests, so we pack and stash it - //do not stash profiler on Navi nodes + //do not stash profiler on Navi or MI300 nodes sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' - stash "ckProfiler.tar.gz" + stash name: "ckProfiler.tar.gz" } if (params.RUN_FULL_QA && mi300_node == 0 ){ // build deb packages for all MI100/200/300 targets and prepare to export @@ -542,7 +546,7 @@ def Build_CK(Map conf=[:]){ archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' archiveArtifacts artifacts: 'composablekernel-tests_*.deb' sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb' - stash "ckprofiler_0.2.0_amd64.deb" + stash name: "ckprofiler_0.2.0_amd64.deb" } } if (params.hipTensor_test && navi_node == 0 ){ @@ -629,6 +633,7 @@ def process_results(Map conf=[:]){ unstash "perf_reduction.log" unstash "perf_splitK_gemm.log" unstash "perf_onnx_gemm.log" + unstash "perf_mixed_gemm.log" sh "./process_qa_data.sh" unstash "ckprofiler_0.2.0_amd64.deb" sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" @@ -716,8 +721,8 @@ pipeline { description: "Run the cppcheck static analysis (default: OFF)") booleanParam( name: "RUN_PERFORMANCE_TESTS", - defaultValue: false, - description: "Run the performance tests (default: OFF)") + defaultValue: true, + description: "Run the performance tests (default: ON)") booleanParam( name: "RUN_CODEGEN_TESTS", defaultValue: true, diff --git a/script/process_perf_data.py b/script/process_perf_data.py index d7e40569f..2c46da8fd 100644 --- a/script/process_perf_data.py +++ b/script/process_perf_data.py @@ -133,11 +133,16 @@ def parse_logfile(logfile): if 'Best Perf' in line: lst=line.split() res.append(lst[4]) - elif 'onnx_gemm' in logfile or 'splitK_gemm' in logfile or 'mixed_gemm' in logfile: + elif 'onnx_gemm' in logfile or 'mixed_gemm' in logfile: for line in open(logfile): if 'Best Perf' in line: lst=line.split() res.append(lst[33]) + elif 'splitK_gemm' in logfile: + for line in open(logfile): + if 'Best Perf' in line: + lst=line.split() + res.append(lst[36]) return res @@ -231,7 +236,7 @@ def main(): sql_hostname = '127.0.0.1' sql_username = os.environ["dbuser"] sql_password = os.environ["dbpassword"] - sql_main_database = 'miopen_perf' + sql_main_database = os.environ["ck_perf_db"] sql_port = 3306 ssh_host = os.environ["dbsship"] ssh_user = os.environ["dbsshuser"] diff --git a/script/run_full_performance_tests.sh b/script/run_full_performance_tests.sh index 90678389f..01ac1b0a3 100755 --- a/script/run_full_performance_tests.sh +++ b/script/run_full_performance_tests.sh @@ -121,26 +121,16 @@ print_log_header $reduction_log $env_type $branch $host_name ./profile_reduce_no_index.sh $verify 2 10 --half 2>&1 | tee -a $reduction_log #run splitK_gemm tests, first correctness verification, then performance -export splitK_gemm_ver_log="perf_splitK_gemm_verify.log" -print_log_header $splitK_gemm_ver_log $env_type $branch $host_name -./profile_splitK_gemm.sh gemm_splitk 0 0 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log -./profile_splitK_gemm.sh gemm_splitk 0 1 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log -./profile_splitK_gemm.sh gemm_splitk 0 2 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log -./profile_splitK_gemm.sh gemm_splitk 0 3 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log -./profile_splitK_gemm.sh gemm_splitk 1 0 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log -./profile_splitK_gemm.sh gemm_splitk 1 1 $verify 1 0 0 4 2>&1 
| tee -a $splitK_gemm_ver_log -./profile_splitK_gemm.sh gemm_splitk 1 2 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log -./profile_splitK_gemm.sh gemm_splitk 1 3 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log export splitK_gemm_log="perf_splitK_gemm.log" print_log_header $splitK_gemm_log $env_type $branch $host_name -./profile_splitK_gemm.sh gemm_splitk 0 0 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log -./profile_splitK_gemm.sh gemm_splitk 0 1 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log -./profile_splitK_gemm.sh gemm_splitk 0 2 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log -./profile_splitK_gemm.sh gemm_splitk 0 3 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log -./profile_splitK_gemm.sh gemm_splitk 1 0 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log -./profile_splitK_gemm.sh gemm_splitk 1 1 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log -./profile_splitK_gemm.sh gemm_splitk 1 2 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log -./profile_splitK_gemm.sh gemm_splitk 1 3 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 0 0 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 0 1 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 0 2 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 0 3 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 1 0 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 1 1 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 1 2 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 1 3 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log #run ONNX gemm tests export onnx_log="perf_onnx_gemm.log" -- GitLab From 9e011bcd6e7735fbcd9045bbd7f2fb98df1446a0 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 18 Mar 2024 10:16:45 -0700 Subject: [PATCH 06/63] update the changelog for ROCm6.1 release (#1205) * update the changelog for ROCm6.1 release * modifty the order of items in changelog, capitalize GEMMs --- CHANGELOG.md | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e3feed2d..fb2ba1975 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,20 +2,27 @@ Full documentation for Composable Kernel is not yet available. -## (Unreleased) CK +## CK for ROCm 6.1.0 -### Fixes -None +### Additions +* Added generic instances for GEMM XDL operations (#1161) +* Added gamma and beta parameters for the layernorm and groupnorm bwd operations (#1133) +* Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126) +* Added an option to vary the number of warm-up cycles and iterations for ckProfiler (#1124) ### Optimizations -None +* New performance optimizations for GEMM operations on MI200 and MI300 architectures (#1135) -### Additions -* Introduced wrapper sublibrary (limited functionality). 
(#1071, #1098, #1108, #1126, #1139) +### Fixes +* Reduced the build time for most GPU architectures (#1084) +* Fixed some conversion issues for fp8 data type (#1099) ### Changes None +### Known issues +None + ## CK for ROCm 6.0.0 ### Fixes @@ -32,7 +39,7 @@ None * Grouped convolution support for small K and C (#822 #879 #897) * Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) * Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) -* Support for Batched Gemm DL (#732) +* Support for Batched GEMM DL (#732) ### Changes * Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) @@ -48,7 +55,7 @@ None ### Additions * New CMake flags: - * "DL_KERNELS"-* Must be set to "ON" in order to build the gemm_dl and batched_gemm_multi_d_dl instances + * "DL_KERNELS"-* Must be set to "ON" in order to build the GEMM DL and batched_gemm_multi_d_dl instances * "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types * "INSTANCES_ONLY" -- Only builds CK library and instances without tests, examples, or profiler * New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler -- GitLab From f52109531b539a9dc8f7f744a104e10558288946 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 19 Mar 2024 08:38:52 -0700 Subject: [PATCH 07/63] Fix a couple of docker issues. (#1206) * do not install sccache by default, only install rocm-llvm-dev for rocm6.1 * add sccache flag to docker build options --- Dockerfile | 18 ++++++++++++------ Jenkinsfile | 9 +++------ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/Dockerfile b/Dockerfile index e3e791729..cc8b1eadf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,7 @@ ARG DEBIAN_FRONTEND=noninteractive ARG ROCMVERSION=6.0 ARG compiler_version="" ARG compiler_commit="" +ARG CK_SCCACHE="" RUN set -xe @@ -32,16 +33,18 @@ RUN if [ "$ROCMVERSION" != "6.1" ]; then \ RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" RUN amdgpu-install -y --usecase=rocm --no-dkms -## Sccache binary built from source for ROCm +## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin -RUN mkdir -p ${SCCACHE_INSTALL_LOCATION} && \ -curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \ -chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION} +ENV CK_SCCACHE=$CK_SCCACHE +RUN if [ "$CK_SCCACHE" != "" ]; then \ + mkdir -p ${SCCACHE_INSTALL_LOCATION} && \ + curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \ + chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \ + fi # Install dependencies -# hipTensor requires rocm-llvm-dev for rocm versions > 6.0.1 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ build-essential \ cmake \ @@ -61,7 +64,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- python3-dev \ python3-pip \ redis \ - rocm-llvm-dev \ sshpass \ stunnel \ software-properties-common \ @@ -75,6 +77,10 @@ RUN apt-get update && 
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- apt-get clean && \ rm -rf /var/lib/apt/lists/* +# hipTensor requires rocm-llvm-dev for rocm versions > 6.0.1 +RUN if [ "$ROCMVERSION" = "6.1" ]; then \ + sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev"; \ + fi # Update the cmake to version 3.27.5 RUN pip install --upgrade cmake==3.27.5 diff --git a/Jenkinsfile b/Jenkinsfile index e60bae2b6..ec3cbd0e2 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -104,7 +104,7 @@ def getDockerImage(Map conf=[:]){ env.DOCKER_BUILDKIT=1 def prefixpath = conf.get("prefixpath", "/opt/rocm") def no_cache = conf.get("no_cache", false) - def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " if(no_cache) { dockerArgs = dockerArgs + " --no-cache " @@ -134,7 +134,7 @@ def buildDocker(install_prefix){ checkout scm def image_name = getDockerImageName() echo "Building Docker for ${image_name}" - def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " echo "Build Args: ${dockerArgs}" try{ @@ -311,7 +311,7 @@ def buildHipClangJob(Map conf=[:]){ if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } - def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -367,9 +367,6 @@ def runCKProfiler(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ - dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " - } def variant = env.STAGE_NAME def retimage -- GitLab From 9e5042691539ba6731158c5c7b83fff4a25f7715 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" 
<49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 20 Mar 2024 09:28:03 -0600 Subject: [PATCH 08/63] Bump rocm-docs-core from 0.36.0 to 0.37.0 in /docs/sphinx (#1208) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.36.0 to 0.37.0. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.36.0...v0.37.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index b3c826773..ae92cc6c1 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.36.0 +rocm-docs-core==0.37.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index ba1d7da44..43853dd3f 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -96,9 +96,7 @@ pygments==2.15.0 # pydata-sphinx-theme # sphinx pyjwt[crypto]==2.6.0 - # via - # pygithub - # pyjwt + # via pygithub pynacl==1.5.0 # via pygithub pytz==2023.3.post1 @@ -113,7 +111,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.36.0 +rocm-docs-core==0.37.0 # via -r requirements.in six==1.16.0 # via -- GitLab From fd0d093e78c18197a4f1b7dafdbc1e2438d28317 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Thu, 21 Mar 2024 13:57:34 -0500 Subject: [PATCH 09/63] Add instances for conv_scale with bf8 in / fp8 out (#1200) * Add bf8 conv fwd instances * Add example * Add profiler mode * Add client example * Fix copyright headers * Format --- client_example/16_convnd_fwd/CMakeLists.txt | 5 ++ .../16_convnd_fwd/conv3d_fwd_bf8.cpp | 46 +++++++++++ example/09_convnd_fwd/CMakeLists.txt | 1 + example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp | 81 +++++++++++++++++++ .../device_grouped_conv_fwd_xdl_instance.hpp | 40 +++++++++ .../gpu/grouped_convolution_forward.hpp | 24 ++++++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 5 ++ ..._xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp | 53 ++++++++++++ profiler/src/profile_grouped_conv_fwd.cpp | 9 ++- 9 files changed, 263 insertions(+), 1 deletion(-) create mode 100644 client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp create mode 100644 example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index e2797415e..e034c468d 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -12,6 +12,11 @@ if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) target_link_libraries(client_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations) endif() +if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_bf8 conv3d_fwd_bf8.cpp) + target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) +endif() + if((DTYPES MATCHES 
"fp32") OR NOT DEFINED DTYPES) add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp b/client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp new file mode 100644 index 000000000..983e0d083 --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::bf8_t; +using OutDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index a3f63350f..195f1857e 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -7,6 +7,7 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) + add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) set(target 1) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp new file mode 100644 index 000000000..0fc9e7b5d --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; +using OutDataType = ck::f8_t; +using ComputeType = ck::bf8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + ComputeType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 
0 : 1; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index e6040e0d9..0f845ca1e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -17,6 +17,10 @@ namespace instance { using F8 = ck::f8_t; #endif +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif + using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; @@ -250,6 +254,42 @@ using device_grouped_conv_fwd_xdl_f8_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_bf8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ComputeType| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#ifdef CK_ENABLE_BF8 + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 
8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8> +#endif + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 7d3071c17..b9712542a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -744,6 +744,23 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances( F8>>>& instances); #endif +#ifdef CK_ENABLE_BF8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + std::vector>>& instances); +#endif + #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances(op_ptrs); + } +#endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 998c1a51a..3825b92af 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -35,4 +35,9 @@ if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp) endif() +if(DTYPES MATCHES "bf8" OR NOT DEFINED DTYPES) + list(APPEND GROUPED_CONV3D_FWD + xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp) +endif() + add_instance_library(device_grouped_conv3d_fwd_instance ${GROUPED_CONV3D_FWD}) diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp new file mode 100644 index 000000000..9f1ceae80 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 7dff5bf5c..1f7273372 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -24,6 +24,7 @@ enum struct ConvDataType BF16_BF16_BF16, // 2 INT8_INT8_INT8, // 3 F8_F8_F8, // 4 + BF8_BF8_F8, // 5 }; #define OP_NAME "grouped_conv_fwd" @@ -38,7 +39,8 @@ static void print_helper_msg() << " 1: Input fp16, Weight fp16, Output fp16\n" << " 2: Input bf16, Weight bf16, Output bf16\n" << " 3: Input int8, Weight int8, Output int8\n" - << " 4: Input fp8, Weight fp8, Output fp8)\n" + << " 4: Input fp8, Weight fp8, Output fp8\n" + << " 5: Input bf8, Weight bf8, Output fp8)\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" << "arg4: verification (0: no, 1: yes)\n" @@ -82,6 +84,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using BF16 = ck::bhalf_t; using INT8 = int8_t; using F8 = ck::f8_t; + using BF8 = ck::bf8_t; // using GNWC = ck::tensor_layout::convolution::GNWC; @@ -257,6 +260,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}); } + else if(data_type == ConvDataType::BF8_BF8_F8) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, F8{}); + } } std::cout << "this data_type & layout is not implemented" << std::endl; -- GitLab From 9c052804a75491865bab0fad49d059b6e4e98cdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 22 Mar 2024 10:40:43 +0100 Subject: [PATCH 10/63] Add elementwise with dynamic vector dim (#1198) * Add elementwise with dynamic vector dim * Reduce number of instaces * Fixes * Fixes --- .../elementwise_permute_4D_fp16.cpp | 25 +- .../elementwise_permute_4D_fp16_col.cpp | 56 +- .../elementwise_permute_4D_fp16_row.cpp | 55 +- 
.../elementwise_permute_4D_fp32_col.cpp | 56 +- .../elementwise_permute_4D_fp32_row.cpp | 55 +- ...hread_group_tensor_slice_transfer_v4r2.hpp | 193 +++++ ...e_elementwise_dynamic_vector_dims_impl.hpp | 422 +++++++++ ...idwise_elementwise_dynamic_vector_dims.hpp | 169 ++++ .../threadwise_tensor_slice_transfer_v3r2.hpp | 804 ++++++++++++++++++ .../gpu/permute_scale.hpp | 116 +-- .../device_permute_scale_instances.hpp | 179 +++- .../gpu/permute_scale/CMakeLists.txt | 18 +- ...evice_permute_scale_1d_fp16_instances.cpp} | 15 +- ...device_permute_scale_1d_fp32_instances.cpp | 24 + ...evice_permute_scale_2d_fp16_instances.cpp} | 15 +- ...device_permute_scale_2d_fp32_instances.cpp | 24 + ...evice_permute_scale_3d_fp16_instances.cpp} | 15 +- ...device_permute_scale_3d_fp32_instances.cpp | 24 + ...evice_permute_scale_4d_fp16_instances.cpp} | 15 +- ...device_permute_scale_4d_fp32_instances.cpp | 24 + ...evice_permute_scale_5d_fp16_instances.cpp} | 15 +- ...device_permute_scale_5d_fp32_instances.cpp | 24 + ...evice_permute_scale_6d_fp16_instances.cpp} | 15 +- ...device_permute_scale_6d_fp32_instances.cpp | 24 + .../profiler/profile_permute_scale_impl.hpp | 46 +- profiler/src/profile_permute_scale.cpp | 29 +- script/profile_permute_scale.sh | 43 + test/permute_scale/test_permute_scale.cpp | 24 +- 28 files changed, 2161 insertions(+), 363 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp create mode 100644 include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp rename library/src/tensor_operation_instance/gpu/permute_scale/{device_permute_scale_1d_instances.cpp => device_permute_scale_1d_fp16_instances.cpp} (56%) create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_fp32_instances.cpp rename library/src/tensor_operation_instance/gpu/permute_scale/{device_permute_scale_2d_instances.cpp => device_permute_scale_2d_fp16_instances.cpp} (56%) create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_fp32_instances.cpp rename library/src/tensor_operation_instance/gpu/permute_scale/{device_permute_scale_3d_instances.cpp => device_permute_scale_3d_fp16_instances.cpp} (56%) create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_fp32_instances.cpp rename library/src/tensor_operation_instance/gpu/permute_scale/{device_permute_scale_4d_instances.cpp => device_permute_scale_4d_fp16_instances.cpp} (56%) create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_fp32_instances.cpp rename library/src/tensor_operation_instance/gpu/permute_scale/{device_permute_scale_5d_instances.cpp => device_permute_scale_5d_fp16_instances.cpp} (56%) create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_fp32_instances.cpp rename library/src/tensor_operation_instance/gpu/permute_scale/{device_permute_scale_6d_instances.cpp => device_permute_scale_6d_fp16_instances.cpp} (56%) create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp32_instances.cpp create mode 100755 script/profile_permute_scale.sh diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp 
b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index 8e9bc64ab..1b28a901c 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -6,7 +6,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -20,15 +20,20 @@ using F32 = float; using ADataType = F16; using BDataType = F16; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using DeviceElementwisePermuteInstance = - ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple - ck::Tuple, // OutDataTypeTuple - PassThrough, // Elementwise op - 4, // NumDim - 8, // MPerThread - ck::Sequence<8>, // InScalarPerVectorSeq - ck::Sequence<1>>; // OutScalarPerVectorSeq +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + PassThrough, // Elementwise + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq template void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp index 9d5fdc0cc..f832601f0 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -7,7 +7,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -21,26 +21,23 @@ using F32 = float; using ADataType = F16; using BDataType = F16; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; -using Scale = ck::tensor_operation::element_wise::Scale; -using DeviceElementwisePermuteInstance = - ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple - ck::Tuple, // OutDataTypeTuple - PassThrough, // ElementwiseOp - UnaryOp, // UnaryOp - Scale, // Scalar - 4, // NumDim - 8, // MPerThread - ck::Sequence<1>, // InScalarPerVectorSeq - ck::Sequence<1>>; // OutScalarPerVectorSeq - -template -void host_elementwise4D(HostTensorB& B_nhwc, - const HostTensorA& A_nchw, - FunctorA functor_a, - FunctorB functor_b, - float scale) +using UnaryOp = ck::tensor_operation::element_wise::Scale; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + UnaryOp, // UnaryOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // 
ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +template +void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) { std::size_t N = A_nchw.mDesc.GetLengths()[0]; std::size_t C = A_nchw.mDesc.GetLengths()[1]; @@ -51,11 +48,8 @@ void host_elementwise4D(HostTensorB& B_nhwc, for(std::size_t c = 0; c < C; ++c) for(std::size_t n = 0; n < N; ++n) { - ADataType tmp_val; auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; - functor_b(tmp_val, a_val); - functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], - scale * tmp_val); + functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val); } } @@ -104,14 +98,8 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths, - {a_strides}, - {b_strides}, - input, - output, - PassThrough{}, - UnaryOp{}, - Scale{scale}); + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -143,7 +131,7 @@ int main() { b_device_buf.FromDevice(b.mData.data()); Tensor host_b(nhwc); - host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale); + host_elementwise4D(host_b, a, UnaryOp{scale}); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp index 7d215cef2..bae85f53c 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -6,7 +6,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -20,36 +20,31 @@ using F32 = float; using ADataType = F16; using BDataType = F16; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; -using Scale = ck::tensor_operation::element_wise::Scale; -using DeviceElementwisePermuteInstance = - ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple - ck::Tuple, // OutDataTypeTuple - PassThrough, // ElementwiseOp - UnaryOp, // UnaryOp - Scale, // Scalar - 4, // NumDim - 8, // MPerThread - ck::Sequence<8>, // InScalarPerVectorSeq - ck::Sequence<1>>; // OutScalarPerVectorSeq - -template -void host_elementwise4D(HostTensorB& B_nhwc, - const HostTensorA& A_nchw, - FunctorA functor_a, - FunctorB functor_b, - float scale) +using UnaryOp = ck::tensor_operation::element_wise::Scale; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + UnaryOp, // UnaryOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +template +void 
host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) { for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) { - ADataType tmp_val; auto a_val = A_nchw(n, c, h, w); - functor_b(tmp_val, a_val); - functor_a(B_nhwc(n, h, w, c), scale * tmp_val); + functor(B_nhwc(n, h, w, c), a_val); } } @@ -86,14 +81,8 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths, - {a_strides}, - {b_strides}, - input, - output, - PassThrough{}, - UnaryOp{}, - Scale{scale}); + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -125,7 +114,7 @@ int main() { b_device_buf.FromDevice(b.mData.data()); Tensor host_b(nhwc); - host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale); + host_elementwise4D(host_b, a, UnaryOp{scale}); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp index 69e411c59..fe7acd301 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -6,7 +6,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -20,26 +20,23 @@ using F32 = float; using ADataType = F32; using BDataType = F32; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; -using Scale = ck::tensor_operation::element_wise::Scale; -using DeviceElementwisePermuteInstance = - ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple - ck::Tuple, // OutDataTypeTuple - PassThrough, // ElementwiseOp - UnaryOp, // UnaryOp - Scale, // Scalar - 4, // NumDim - 1, // MPerThread - ck::Sequence<1>, // InScalarPerVectorSeq - ck::Sequence<1>>; // OutScalarPerVectorSeq - -template -void host_elementwise4D(HostTensorB& B_nhwc, - const HostTensorA& A_nchw, - FunctorA functor_a, - FunctorB functor_b, - float scale) +using UnaryOp = ck::tensor_operation::element_wise::Scale; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + UnaryOp, // UnaryOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<1>, // InScalarPerVectorSeq + ck::Sequence<1>>; // OutScalarPerVectorSeq + +template +void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) { std::size_t N = A_nchw.mDesc.GetLengths()[0]; std::size_t C = A_nchw.mDesc.GetLengths()[1]; @@ -50,11 +47,8 @@ void host_elementwise4D(HostTensorB& B_nhwc, for(std::size_t c 
= 0; c < C; ++c) for(std::size_t n = 0; n < N; ++n) { - ADataType tmp_val; auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; - functor_b(tmp_val, a_val); - functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], - scale * tmp_val); + functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val); } } @@ -104,14 +98,8 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths, - {a_strides}, - {b_strides}, - input, - output, - PassThrough{}, - UnaryOp{}, - Scale{scale}); + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -143,7 +131,7 @@ int main() { b_device_buf.FromDevice(b.mData.data()); Tensor host_b(nhwc); - host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale); + host_elementwise4D(host_b, a, UnaryOp{scale}); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp index 69f40fe16..aebdb37d9 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -6,7 +6,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -20,36 +20,31 @@ using F32 = float; using ADataType = F32; using BDataType = F32; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; -using Scale = ck::tensor_operation::element_wise::Scale; -using DeviceElementwisePermuteInstance = - ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple - ck::Tuple, // OutDataTypeTuple - PassThrough, // ElementwiseOp - UnaryOp, // UnaryOp - Scale, // Scalar - 4, // NumDim - 8, // MPerThread - ck::Sequence<8>, // InScalarPerVectorSeq - ck::Sequence<1>>; // OutScalarPerVectorSeq - -template -void host_elementwise4D(HostTensorB& B_nhwc, - const HostTensorA& A_nchw, - FunctorA functor_a, - FunctorB functor_b, - float scale) +using UnaryOp = ck::tensor_operation::element_wise::Scale; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + UnaryOp, // UnaryOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +template +void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) { for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) { - ADataType tmp_val; auto a_val = A_nchw(n, c, h, w); - functor_b(tmp_val, 
a_val); - functor_a(B_nhwc(n, h, w, c), scale * tmp_val); + functor(B_nhwc(n, h, w, c), a_val); } } @@ -86,14 +81,8 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths, - {a_strides}, - {b_strides}, - input, - output, - PassThrough{}, - UnaryOp{}, - Scale{scale}); + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -125,7 +114,7 @@ int main() { b_device_buf.FromDevice(b.mData.data()); Tensor host_b(nhwc); - host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale); + host_elementwise4D(host_b, a, UnaryOp{scale}); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp new file mode 100644 index 000000000..aa1f7c573 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp" + +namespace ck { + +/** + * @brief Blockwise data transfer + * + * This version does following things to avoid scratch memory issue + * 1. Use StaticallyIndexedArray instead of C array for thread buffer + * 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor + * 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate + * + */ +template +struct ThreadGroupTensorSliceTransfer_v4r2 +{ + static constexpr index_t nDim = + remove_reference_t>::GetNumOfDimension(); + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v4r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + + { + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == SrcDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_for<0, nSrc, 1>{}([&](auto src_i) { + static_assert(nDim == + remove_cvref_t>::GetNumOfDimension(), + "wrong! nDim not consistent"); + }); + + static_for<0, nDst, 1>{}([&](auto dst_i) { + static_assert(nDim == + remove_cvref_t>::GetNumOfDimension(), + "wrong! nDim not consistent"); + }); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); + } + } + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers& dst_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffer& src_bufs, + const DstDescs& dst_descs, + DstBuffer& dst_bufs, + Number thread_scratch_id) + { + RunRead(src_descs, src_bufs, thread_scratch_id); + RunWrite(dst_descs, dst_bufs, thread_scratch_id); + } + + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp new file mode 100644 index 000000000..4dba95e5d --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp @@ -0,0 +1,422 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
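+//
+// Usage sketch (for orientation only; the tiling values below are the ones used by the
+// updated example/44_elementwise_permute sources and are not the only valid choice):
+// this implementation is parameterised by block/thread tiling instead of a single
+// MPerThread, so a 4-D fp16 permute instance might be declared as
+//
+//   using DeviceElementwisePermuteInstance =
+//       ck::tensor_operation::device::DeviceElementwiseImpl<
+//           ck::Tuple<ck::half_t>,                            // InDataTypeTuple
+//           ck::Tuple<ck::half_t>,                            // OutDataTypeTuple
+//           ck::tensor_operation::element_wise::PassThrough,  // Elementwise op
+//           4,                                                // NumDim
+//           256,                                              // BlockSize
+//           128,                                              // M0PerBlock
+//           128,                                              // M1PerBlock
+//           8,                                                // M0PerThread
+//           8,                                                // M1PerThread
+//           ck::Sequence<1, 0>,                               // ThreadClusterArrangeOrder
+//           ck::Sequence<8>,                                  // InScalarPerVectorSeq
+//           ck::Sequence<8>>;                                 // OutScalarPerVectorSeq
+//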
+ +#pragma once + +#include +#include + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceElementwiseImpl + : public DeviceElementwise +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static auto GenerateInDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + static auto GenerateOutDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple()); + using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); + + static index_t GetLowestStrideDim(const std::array& strides) + { + index_t most_continous_dim = NumDim - 1; + index_t most_continous_dim_stride = strides[most_continous_dim]; + for(index_t dim = 0; dim < NumDim; dim++) + { + if(strides[dim] < most_continous_dim_stride) + { + most_continous_dim_stride = strides[dim]; + most_continous_dim = dim; + } + } + return most_continous_dim; + } + + template + static auto PadInputOutputDescriptor(const InOutDescriptor& desc) + { + const auto M0 = desc.GetLength(I0); + const auto M1 = desc.GetLength(I1); + const auto pad_M0 = math::integer_divide_ceil(M0, M0PerThread) * M0PerThread - M0; + const auto pad_M1 = math::integer_divide_ceil(M1, M1PerThread) * M1PerThread - M1; + + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_right_pad_transform(M0, pad_M0), make_right_pad_transform(M1, pad_M1)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return padded_desc; + } + + static auto GenerateBatchDimsLenghtsTuple(const std::array& lengths, + const index_t M0_dim, + const index_t M1_dim) + { + // Generate batch dims, they will be merged to M0 + // Add one more dim than needed in case that M0 is equal to M1 + // If M0 is equal to M1, then will be one more batch dim + std::array batch_dims; + index_t batch_dim = 0; + for(index_t i = 0; i < NumDim; i++) + { + if(i != M0_dim && i != M1_dim) + { + batch_dims[batch_dim] = lengths[i]; + batch_dim++; + } + } + // Add dummy dim if M0_dim is not equal to M1_dim + if(M0_dim != M1_dim && NumDim >= 2) + batch_dims[NumDim - 2] = 1; + return generate_tuple([&](auto I) { return batch_dims[I]; }, Number{}); + } + + static auto MakeDescriptor(const std::array& lengths, + const std::array& in_strides, + const std::array& out_strides, + const std::array& desc_strides) + { + const auto M0_dim = GetLowestStrideDim(out_strides); + const auto M1_dim = GetLowestStrideDim(in_strides); + + // If M0_dim is equal to M1_dim, then make M0_dim dummy + const auto 
M0 = M0_dim == M1_dim ? I1 : lengths[M0_dim]; + const auto M1 = lengths[M1_dim]; + const auto M0_stride = M0_dim == M1_dim ? I1 : desc_strides[M0_dim]; + const auto M1_stride = desc_strides[M1_dim]; + + const auto batch_dims_lenghts = GenerateBatchDimsLenghtsTuple(lengths, M0_dim, M1_dim); + const auto batch_dims_strides = GenerateBatchDimsLenghtsTuple(desc_strides, M0_dim, M1_dim); + + const auto desc = make_naive_tensor_descriptor( + concat_tuple(batch_dims_lenghts, make_tuple(M0), make_tuple(M1)), + concat_tuple(batch_dims_strides, make_tuple(M0_stride), make_tuple(M1_stride))); + // Merged batch dims with M0 + const auto transforms = + make_tuple(make_merge_transform(concat_tuple(batch_dims_lenghts, make_tuple(M0))), + make_pass_through_transform(M1)); + using BatchElemsSequence = + typename arithmetic_sequence_gen<0, decltype(batch_dims_lenghts)::Size() + 1, 1>::type; + const auto lower_dims = make_tuple(BatchElemsSequence{}, Sequence{}); + const auto upper_dims = make_tuple(Sequence<0>{}, Sequence<1>{}); + // desc: (merged_dims + M0, M1) + auto merged_desc = transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims); + return PadInputOutputDescriptor(merged_desc); + } + + template + static auto GenerateInOutGridDescTuple() + { + std::array ones; + for(index_t d = 0; d < NumDim; d++) + { + ones[d] = 1; + } + + return generate_tuple([&](auto) { return MakeDescriptor(ones, ones, ones, ones); }, + Number{}); + }; + + using InGridDescTuple = decltype(GenerateInOutGridDescTuple()); + using OutGridDescTuple = decltype(GenerateInOutGridDescTuple()); + + using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseElementwiseOp = GridwiseElementwise; + + using GridwiseElementwiseOpSameInOutVectorDim = GridwiseElementwise; + + struct Argument : public BaseArgument + { + Argument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) + + : lengths_(lengths), + inStridesArray_(inStridesArray), + outStridesArray_(outStridesArray), + elementwise_op_(elementwise_op) + { + in_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(in_dev_buffers[I.value]); + }, + Number{}); + + out_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(out_dev_buffers[I.value]); + }, + Number{}); + } + + InDataTypePointerTuple in_dev_buffers_; + OutDataTypePointerTuple out_dev_buffers_; + + std::array lengths_; + std::array, NumInput> inStridesArray_; + std::array, NumOutput> outStridesArray_; + + ElementwiseOperation elementwise_op_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + auto in_grid_desc_tuple = generate_tuple( + [&](auto src_i) { + // Use Strides from first tensor to assert that M0 dim and + // M1 dim are the same for each tensor. 
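+                    // The [I0] input/output strides only select which dimensions become
+                    // M1 (input vector dim) and M0 (output vector dim); the per-tensor
+                    // strides passed as the last argument shape this tensor's descriptor.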
+ return MakeDescriptor(arg.lengths_, + arg.inStridesArray_[I0], + arg.outStridesArray_[I0], + arg.inStridesArray_[src_i]); + }, + Number{}); + + auto out_grid_desc_tuple = generate_tuple( + [&](auto dst_i) { + return MakeDescriptor(arg.lengths_, + arg.inStridesArray_[I0], + arg.outStridesArray_[I0], + arg.outStridesArray_[dst_i]); + }, + Number{}); + + const index_t M0 = in_grid_desc_tuple.At(I0).GetLength(Number{}); + const index_t M1 = in_grid_desc_tuple.At(I0).GetLength(Number{}); + + const auto block_2_tile_map = Block2TileMap(M0, M1); + const index_t grid_size = block_2_tile_map.CalculateGridSize(M0, M1); + + const bool in_out_same_vector_dim = GetLowestStrideDim(arg.inStridesArray_[I0]) == + GetLowestStrideDim(arg.outStridesArray_[I0]); + + const auto kernel = in_out_same_vector_dim + ? kernel_elementwise + : kernel_elementwise; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + in_grid_desc_tuple, + out_grid_desc_tuple, + arg.in_dev_buffers_, + arg.out_dev_buffers_, + block_2_tile_map, + arg.elementwise_op_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + const index_t M0_dim = GetLowestStrideDim(arg.inStridesArray_[I0]); + const index_t M1_dim = GetLowestStrideDim(arg.outStridesArray_[I0]); + + auto IsScalarPerVectorValid = [&](const std::array& lengths, + const std::array& strides, + index_t scalarPerVector, + index_t M_dim) { + if(scalarPerVector == 1) + { + return true; + } + if(strides[M_dim] == 1 && lengths[M_dim] % scalarPerVector == 0) + { + return true; + } + return false; + }; + + bool is_valid = true; + static_for<0, NumInput, 1>{}([&](auto I) { + static_assert(M0PerThread % InScalarPerVectorSeq::At(I) == 0 && + M1PerThread % InScalarPerVectorSeq::At(I) == 0); + is_valid &= IsScalarPerVectorValid( + arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I), M0_dim); + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + static_assert(M0PerThread % OutScalarPerVectorSeq::At(I) == 0 && + M1PerThread % OutScalarPerVectorSeq::At(I) == 0); + is_valid &= IsScalarPerVectorValid( + arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I), M1_dim); + }); + + return is_valid; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) + { + return Argument{lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op}; + } + + std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) override + { + return std::make_unique(lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op); + } + + static auto MakeInvoker() { return Invoker{}; } + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() 
const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceElementwiseImpl<"; + str << NumDim << ", "; + str << BlockSize << ", "; + str << M0PerBlock << ", "; + str << M1PerBlock << ", "; + str << M0PerThread << ", "; + str << M1PerThread << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp new file mode 100644 index 000000000..2a906a143 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor/static_tensor.hpp" +#include "ck/utility/common_header.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, + const OutGridDescTuple out_grid_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const Block2TileMap block_2_tile_map, + const ElementwiseOperation elementwise_op) +{ + GridwiseElementwiseFunctor::Run(in_grid_desc_tuple, + out_grid_desc_tuple, + p_in_global_tuple, + p_out_global_tuple, + block_2_tile_map, + elementwise_op); +} + +template +struct GridwiseElementwise +{ + static constexpr index_t NumInput = InDataTypePointerTuple::Size(); + static constexpr index_t NumOutput = OutDataTypePointerTuple::Size(); + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size() && + NumInput == InGridDescTuple::Size() && NumOutput == OutGridDescTuple::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + __device__ static void Run(const InGridDescTuple& in_grid_desc_tuple, + const OutGridDescTuple& out_grid_desc_tuple, + const InDataTypePointerTuple& p_in_global_tuple, + const OutDataTypePointerTuple& p_out_global_tuple, + const Block2TileMap& block_2_tile_map, + const ElementwiseOperation& elementwise_op) + { + + constexpr auto src_datas = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return DataType{}; + }, + Number{}); + + constexpr auto dst_datas = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return DataType{}; + }, + Number{}); + + const auto in_global_buf_tuple = generate_tuple( + [&](auto I) { + return make_dynamic_buffer( + p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + auto out_global_buf_tuple 
= generate_tuple( + [&](auto I) { + return make_dynamic_buffer( + p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + const auto block_work_idx = + block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t m0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock); + const index_t m1_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock); + const auto thread_grid_offset = + make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + + using ThisThreadBlock = ThisThreadBlock; + // If src and dst have same vector dim, then: + // M0 dim - for src and dst vector load/store + // else: + // M0 dim - for dst vector load + // M1 dim - for src vector store + using SrcDimAccessOrder = Sequence<0, 1>; + using DstDimAccessOrder = + std::conditional_t, Sequence<1, 0>>; + using SrcVectorDim = Number<1>; + using DstVectorDim = std::conditional_t, Number<0>>; + + using ThreadClusterLengths = + Sequence{}, Number{}>; + + auto global_to_global_transfer = ThreadGroupTensorSliceTransfer_v4r2< + ThisThreadBlock, + ElementwiseOperation, + uniform_sequence_gen_t(InMemoryDataOperationEnum::Set)>, + Sequence, + ThreadClusterLengths, + ThreadClusterArrangeOrder, + decltype(src_datas), + decltype(dst_datas), + InGridDescTuple, + OutGridDescTuple, + SrcDimAccessOrder, + DstDimAccessOrder, + SrcVectorDim{}, + DstVectorDim{}, + InScalarPerVectorSeq, + OutScalarPerVectorSeq, + uniform_sequence_gen_t, + uniform_sequence_gen_t, + uniform_sequence_gen_t, + uniform_sequence_gen_t>{in_grid_desc_tuple, + thread_grid_offset, + out_grid_desc_tuple, + thread_grid_offset, + elementwise_op}; + global_to_global_transfer.Run( + in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp new file mode 100644 index 000000000..f0d793456 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp @@ -0,0 +1,804 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor/static_tensor.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. 
Use thread buffer +template +struct ThreadwiseTensorSliceTransfer_v3r2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto src_i) { + src_coords_(src_i) = + make_tensor_coordinate(src_descs.At(src_i), src_slice_origin_idxs[src_i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto dst_i) { + dst_coords_(dst_i) = + make_tensor_coordinate(dst_descs.At(dst_i), dst_slice_origin_idxs[dst_i]); + }); + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access_tuple = generate_tuple( + [&](auto src_i) { + return generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + }, + Number{}); + + constexpr auto src_access_lengths_tuple = generate_tuple( + [&](auto src_i) { + return SliceLengths{} / src_scalar_per_access_tuple.At(src_i); + static_assert( + SliceLengths::At(SrcVectorDim) % SrcsScalarPerVector::At(src_i) == 0, + "SliceLengths[SrcVectorDim] must be divisible by SrcsScalarPerVector"); + }, + Number{}); + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths_tuple = generate_tuple( + [&](auto src_i) { + return container_reorder_given_new2old(src_access_lengths_tuple.At(src_i), + src_dim_access_order); + }, + Number{}); + + // make forward steps + const auto src_forward_steps_tuple = generate_tuple( + [&](auto src_i) { + return generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = + (i.value == j.value) ? src_scalar_per_access_tuple.At(src_i)[i] : 0; + }); + + return make_tensor_coordinate_step(src_descs.At(src_i), forward_step_idx); + }, + Number{}); + }, + Number{}); + + // make backward steps + const auto src_backward_steps_tuple = generate_tuple( + [&](auto src_i) { + return generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) + ? 
-src_scalar_per_access_tuple.At(src_i)[i] + : 0; + }); + + return make_tensor_coordinate_step(src_descs.At(src_i), backward_step_idx); + }, + Number{}); + }, + Number{}); + + // loop over tensor and copy + static_for<0, nSrc, 1>{}([&](auto src_i) { + static_ford>{}( + [&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths_tuple[j] + + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? ordered_src_access_idx[i] + : ordered_src_access_lengths_tuple.At(src_i)[i] - + 1 - ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access_tuple.At(src_i); + }(); + + constexpr auto src_data_idx_seq = + generate_sequence_v2([&](auto i) { return Number{}; }, + Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_descs.At(src_i), src_coords_.At(src_i)); + + using src_vector_type = vector_type_maker_t, + SrcsScalarPerVector::At(src_i)>; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = + src_vector_type{src_bufs.At(src_i).template Get( + src_coords_.At(src_i).GetOffset(), is_src_valid)}; + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .At(src_i) + .template SetAsType( + src_data_idx_seq, + src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < + ordered_src_access_lengths_tuple.At(src_i)[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == + ordered_src_access_lengths_tuple.At(src_i)[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_descs.At(src_i), + src_coords_.At(src_i), + src_forward_steps_tuple.At(src_i)[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_descs.At(src_i), + src_coords_.At(src_i), + src_backward_steps_tuple.At(src_i)[src_dim_access_order[i]]); + } + } + }); + }); + }); + + static_for<0, nSrc, 1>{}([&](auto src_i) { + // move src coordinate back to slice origin (or not) + if constexpr(SrcsResetCoordinateAfterRun::At(src_i)) + { + const auto src_reset_step = make_tensor_coordinate_step( + src_descs.At(src_i), GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs.At(src_i), src_coords_.At(src_i), src_reset_step); + } + }); + } + + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) + { + // TODO: Add support for CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + // (it requires to add Elementwise support in transpose_vectors) + static_ford{}([&](auto idx) { + const auto src_data_refs = generate_tie( + 
[&](auto src_i) -> const auto& { + return src_thread_scratch_tuple_[thread_scratch_id].At(src_i)[idx]; + }, + Number{}); + + auto dst_data_refs = generate_tie( + [&](auto dst_i) -> auto& { return dst_thread_scratch_tuple_.At(dst_i)(idx); }, + Number{}); + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + } + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers& dst_bufs, + Number thread_scratch_id = Number{}) + { + // if there is transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access_tuple = generate_tuple( + [&](auto dst_i) { + return generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + }, + Number{}); + + constexpr auto dst_access_lengths_tuple = generate_tuple( + [&](auto dst_i) { return SliceLengths{} / dst_scalar_per_access_tuple.At(dst_i); }, + Number{}); + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths_tuple = generate_tuple( + [&](auto dst_i) { + return container_reorder_given_new2old(dst_access_lengths_tuple.At(dst_i), + dst_dim_access_order); + }, + Number{}); + + // make forward steps + const auto dst_forward_steps_tuple = generate_tuple( + [&](auto dst_i) { + return generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = + (i.value == j.value) ? dst_scalar_per_access_tuple.At(dst_i)[i] : 0; + }); + + return make_tensor_coordinate_step(dst_descs.At(dst_i), forward_step_idx); + }, + Number{}); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps_tuple = generate_tuple( + [&](auto dst_i) { + return generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) + ? -dst_scalar_per_access_tuple.At(dst_i)[i] + : 0; + }); + + return make_tensor_coordinate_step(dst_descs.At(dst_i), backward_step_idx); + }, + Number{}); + }, + Number{}); + + // loop over tensor and copy + static_for<0, nDst, 1>{}([&](auto dst_i) { + static_ford>{}( + [&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths_tuple.At(dst_i)[j] + + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? 
ordered_dst_access_idx[i] + : ordered_dst_access_lengths_tuple.At(dst_i)[i] - + 1 - ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access_tuple.At(dst_i); + }(); + + constexpr auto dst_data_idx_seq = + generate_sequence_v2([&](auto i) { return Number{}; }, + Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid( + dst_descs.At(dst_i), dst_coords_.At(dst_i)); + + using dst_vector_type = vector_type_maker_t, + DstsScalarPerVector::At(dst_i)>; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_tuple_.At(dst_i).template GetAsType( + dst_data_idx_seq)}; + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(dst_i.value)); + + // copy data from dst_vector_container to dst_buf + dst_bufs.At(dst_i).template Update( + dst_coords_.At(dst_i).GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < + ordered_dst_access_lengths_tuple.At(dst_i)[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == + ordered_dst_access_lengths_tuple.At(dst_i)[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_descs.At(dst_i), + dst_coords_.At(dst_i), + dst_forward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_descs.At(dst_i), + dst_coords_.At(dst_i), + dst_backward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]); + } + } + }); + }); + }); + + // move dst coordinate back to slice origin (or not) + static_for<0, nDst, 1>{}([&](auto dst_i) { + if constexpr(DstsResetCoordinateAfterRun::At(dst_i)) + { + const auto dst_reset_step = make_tensor_coordinate_step( + dst_descs.At(dst_i), GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs.At(dst_i), dst_coords_.At(dst_i), dst_reset_step); + } + }); + } + + template + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index 
ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + template + __device__ static constexpr auto GetDstCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access.At(dst_i); + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + const Index& src_slice_origin_step_idx) + { + static_for<0, nSrc, 1>{}([&](auto src_i) { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcsResetCoordinateAfterRun::At(src_i) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = + make_tensor_coordinate_step(src_descs.At(src_i), adjusted_step_idx); + + move_tensor_coordinate(src_descs.At(src_i), src_coords_.At(src_i), adjusted_step); + }); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + const Index& dst_slice_origin_step_idx) + { + static_for<0, nDst, 1>{}([&](auto dst_i) { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstsResetCoordinateAfterRun::At(dst_i) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? 
+ const auto adjusted_step = + make_tensor_coordinate_step(dst_descs.At(dst_i), adjusted_step_idx); + + move_tensor_coordinate(dst_descs.At(dst_i), dst_coords_.At(dst_i), adjusted_step); + }); + } + + template + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = + container_push_back(sequence_to_tuple_of_number(src_access_lengths), + Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + template + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = + container_push_back(sequence_to_tuple_of_number(dst_access_lengths), + Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto MakeSrcThreadScratchTuple() + { + return generate_tuple( + [&](auto src_i) { + constexpr auto src_thread_scratch_desc = + decltype(GetSrcThreadScratchDescriptor()){}; + using SrcThreadScratch = + StaticTensorTupleOfVectorBuffer, + SrcsScalarPerVector::At(src_i), + decltype(src_thread_scratch_desc), + true>; + return SrcThreadScratch{}; + }, + Number{}); + } + + __device__ static constexpr auto MakeDstThreadScratchTuple() + { + return generate_tuple( + [&](auto dst_i) { + constexpr auto dst_thread_scratch_desc = + decltype(GetDstThreadScratchDescriptor()){}; + using DstThreadScratch = + StaticTensorTupleOfVectorBuffer, + DstsScalarPerVector::At(dst_i), + 
decltype(dst_thread_scratch_desc), + true>; + return DstThreadScratch{}; + }, + Number{}); + } + + private: + using SrcThreadScratchTuple = decltype(MakeSrcThreadScratchTuple()); + using DstThreadScratchTuple = decltype(MakeDstThreadScratchTuple()); + + StaticallyIndexedArray src_thread_scratch_tuple_; + + DstThreadScratchTuple dst_thread_scratch_tuple_; + + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp index 4b3f40e21..4f5d022f9 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp @@ -7,7 +7,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -19,125 +19,67 @@ namespace instance { #ifdef CK_ENABLE_FP16 void add_device_permute_scale_1d_f16_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 1>>>&); + std::vector, ck::Tuple, element_wise::Scale, 1>>>&); void add_device_permute_scale_2d_f16_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 2>>>&); + std::vector, ck::Tuple, element_wise::Scale, 2>>>&); void add_device_permute_scale_3d_f16_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 3>>>&); + std::vector, ck::Tuple, element_wise::Scale, 3>>>&); void add_device_permute_scale_4d_f16_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 4>>>&); + std::vector, ck::Tuple, element_wise::Scale, 4>>>&); void add_device_permute_scale_5d_f16_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 5>>>&); + std::vector, ck::Tuple, element_wise::Scale, 5>>>&); void add_device_permute_scale_6d_f16_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 6>>>&); + std::vector, ck::Tuple, element_wise::Scale, 6>>>&); #endif #ifdef CK_ENABLE_FP32 void add_device_permute_scale_1d_f32_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 1>>>&); + std::vector, ck::Tuple, element_wise::Scale, 1>>>&); void add_device_permute_scale_2d_f32_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 2>>>&); + std::vector, ck::Tuple, element_wise::Scale, 2>>>&); void add_device_permute_scale_3d_f32_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 3>>>&); + std::vector, ck::Tuple, element_wise::Scale, 3>>>&); void add_device_permute_scale_4d_f32_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 4>>>&); + std::vector, ck::Tuple, element_wise::Scale, 4>>>&); void add_device_permute_scale_5d_f32_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 5>>>&); + std::vector, ck::Tuple, element_wise::Scale, 5>>>&); void add_device_permute_scale_6d_f32_instances( - std::vector, - ck::Tuple, - PassThrough, - 
element_wise::UnarySquare, - Scale, - 6>>>&); + std::vector, ck::Tuple, element_wise::Scale, 6>>>&); #endif template struct DeviceOperationInstanceFactory< - ck::tensor_operation::device::DeviceElementwise> + ck::tensor_operation::device:: + DeviceElementwise> { - using DeviceOp = DeviceElementwise; + using DeviceOp = + DeviceElementwise; static auto GetInstances() { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp b/library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp index a672ab22d..8a2200541 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp @@ -2,7 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/utility/data_type.hpp" namespace ck { @@ -13,26 +13,175 @@ namespace instance { using F16 = ck::half_t; using F32 = float; -using Pass = ck::tensor_operation::element_wise::PassThrough; -using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; -using Scale = ck::tensor_operation::element_wise::Scale; - // clang-format off -template +template using device_permute_scale_f16_instances = std::tuple < - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 1, ck::Sequence<1>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 8, ck::Sequence<8>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 4, ck::Sequence<4>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 2, ck::Sequence<2>, ck::Sequence<1>> + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, 
ck::Sequence<4>>, + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 256, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 256, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 128, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 64, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + +#if 0 + // Disabled instances to improve compilation time + // They listed here to show other possible combinations of parameters + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 512, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 512, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 256, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, 
ElementwiseOp, NDims, 128, 32, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 32, 256, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 16, 256, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 16, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 16, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 256, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 128, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 128, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 256, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 128, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 64, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, +#endif + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + 
DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>> + >; -template +template using device_permute_scale_f32_instances = std::tuple< - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 1, ck::Sequence<1>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 8, ck::Sequence<8>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 4, ck::Sequence<4>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 2, ck::Sequence<2>, ck::Sequence<1>> + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 256, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, 
ck::Tuple, ElementwiseOp, NDims, 128, 32, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 256, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 128, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 64, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + +#if 0 + // Disabled instances to improve compilation time + // They listed here to show other possible combinations of parameters + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 512, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 512, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 256, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 32, 256, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 16, 256, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 16, 128, 4, 8, 
ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 16, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 256, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 128, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 128, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 256, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 128, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 64, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, +#endif + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, 
NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>> >; // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt index 86652c0bf..fc0da56a9 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt @@ -1,7 +1,13 @@ add_instance_library(device_permute_scale_instance - device_permute_scale_1d_instances.cpp - device_permute_scale_2d_instances.cpp - device_permute_scale_3d_instances.cpp - device_permute_scale_4d_instances.cpp - device_permute_scale_5d_instances.cpp - device_permute_scale_6d_instances.cpp) + device_permute_scale_1d_fp16_instances.cpp + device_permute_scale_2d_fp16_instances.cpp + device_permute_scale_3d_fp16_instances.cpp + device_permute_scale_4d_fp16_instances.cpp + device_permute_scale_5d_fp16_instances.cpp + device_permute_scale_6d_fp16_instances.cpp + device_permute_scale_1d_fp32_instances.cpp + device_permute_scale_2d_fp32_instances.cpp + device_permute_scale_3d_fp32_instances.cpp + device_permute_scale_4d_fp32_instances.cpp + device_permute_scale_5d_fp32_instances.cpp + device_permute_scale_6d_fp32_instances.cpp) diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_fp16_instances.cpp similarity index 56% rename from library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_instances.cpp rename to library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_fp16_instances.cpp index 77d3baf4d..4ee9c1b1c 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_instances.cpp +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_fp16_instances.cpp @@ -9,18 +9,13 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_permute_scale_1d_f16_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 1>>>& instances) -{ - add_device_operation_instances(instances, device_permute_scale_f16_instances<1>{}); -} +using Scale = element_wise::Scale; -void add_device_permute_scale_1d_f32_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 1>>>& instances) +void add_device_permute_scale_1d_f16_instances( + std::vector, ck::Tuple, Scale, 1>>>& + instances) { - add_device_operation_instances(instances, device_permute_scale_f32_instances<1>{}); + add_device_operation_instances(instances, device_permute_scale_f16_instances<1, Scale>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_fp32_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_fp32_instances.cpp new file mode 100644 index 000000000..672acda07 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_fp32_instances.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Scale = element_wise::Scale; + +void add_device_permute_scale_1d_f32_instances( + std::vector, ck::Tuple, Scale, 1>>>& + instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<1, Scale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_fp16_instances.cpp similarity index 56% rename from library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_instances.cpp rename to library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_fp16_instances.cpp index 399b6b049..b4a5b107f 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_instances.cpp +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_fp16_instances.cpp @@ -9,18 +9,13 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_permute_scale_2d_f16_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 2>>>& instances) -{ - add_device_operation_instances(instances, device_permute_scale_f16_instances<2>{}); -} +using Scale = element_wise::Scale; -void add_device_permute_scale_2d_f32_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 2>>>& instances) +void add_device_permute_scale_2d_f16_instances( + std::vector, ck::Tuple, Scale, 2>>>& + instances) { - add_device_operation_instances(instances, device_permute_scale_f32_instances<2>{}); + add_device_operation_instances(instances, device_permute_scale_f16_instances<2, Scale>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_fp32_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_fp32_instances.cpp new file mode 100644 index 000000000..5b7b353fc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_fp32_instances.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Scale = element_wise::Scale; + +void add_device_permute_scale_2d_f32_instances( + std::vector, ck::Tuple, Scale, 2>>>& + instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<2, Scale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_fp16_instances.cpp similarity index 56% rename from library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_instances.cpp rename to library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_fp16_instances.cpp index 29f2f9fd5..63876cbc4 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_instances.cpp +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_fp16_instances.cpp @@ -9,18 +9,13 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_permute_scale_3d_f16_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 3>>>& instances) -{ - add_device_operation_instances(instances, device_permute_scale_f16_instances<3>{}); -} +using Scale = element_wise::Scale; -void add_device_permute_scale_3d_f32_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 3>>>& instances) +void add_device_permute_scale_3d_f16_instances( + std::vector, ck::Tuple, Scale, 3>>>& + instances) { - add_device_operation_instances(instances, device_permute_scale_f32_instances<3>{}); + add_device_operation_instances(instances, device_permute_scale_f16_instances<3, Scale>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_fp32_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_fp32_instances.cpp new file mode 100644 index 000000000..f8772967d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_fp32_instances.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Scale = element_wise::Scale; + +void add_device_permute_scale_3d_f32_instances( + std::vector, ck::Tuple, Scale, 3>>>& + instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<3, Scale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_fp16_instances.cpp similarity index 56% rename from library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_instances.cpp rename to library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_fp16_instances.cpp index 3ad1d59e6..553772e1d 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_instances.cpp +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_fp16_instances.cpp @@ -9,18 +9,13 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_permute_scale_4d_f16_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 4>>>& instances) -{ - add_device_operation_instances(instances, device_permute_scale_f16_instances<4>{}); -} +using Scale = element_wise::Scale; -void add_device_permute_scale_4d_f32_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 4>>>& instances) +void add_device_permute_scale_4d_f16_instances( + std::vector, ck::Tuple, Scale, 4>>>& + instances) { - add_device_operation_instances(instances, device_permute_scale_f32_instances<4>{}); + add_device_operation_instances(instances, device_permute_scale_f16_instances<4, Scale>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_fp32_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_fp32_instances.cpp new file mode 100644 index 000000000..f1ecc0ccf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_fp32_instances.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Scale = element_wise::Scale; + +void add_device_permute_scale_4d_f32_instances( + std::vector, ck::Tuple, Scale, 4>>>& + instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<4, Scale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_fp16_instances.cpp similarity index 56% rename from library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_instances.cpp rename to library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_fp16_instances.cpp index 6a4383bc9..adb391888 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_instances.cpp +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_fp16_instances.cpp @@ -9,18 +9,13 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_permute_scale_5d_f16_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 5>>>& instances) -{ - add_device_operation_instances(instances, device_permute_scale_f16_instances<5>{}); -} +using Scale = element_wise::Scale; -void add_device_permute_scale_5d_f32_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 5>>>& instances) +void add_device_permute_scale_5d_f16_instances( + std::vector, ck::Tuple, Scale, 5>>>& + instances) { - add_device_operation_instances(instances, device_permute_scale_f32_instances<5>{}); + add_device_operation_instances(instances, device_permute_scale_f16_instances<5, Scale>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_fp32_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_fp32_instances.cpp new file mode 100644 index 000000000..ed53e09b7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_fp32_instances.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Scale = element_wise::Scale; + +void add_device_permute_scale_5d_f32_instances( + std::vector, ck::Tuple, Scale, 5>>>& + instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<5, Scale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp16_instances.cpp similarity index 56% rename from library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_instances.cpp rename to library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp16_instances.cpp index 71e5867e9..abf630e43 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_instances.cpp +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp16_instances.cpp @@ -9,18 +9,13 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_permute_scale_6d_f16_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 6>>>& instances) -{ - add_device_operation_instances(instances, device_permute_scale_f16_instances<6>{}); -} +using Scale = element_wise::Scale; -void add_device_permute_scale_6d_f32_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 6>>>& instances) +void add_device_permute_scale_6d_f16_instances( + std::vector, ck::Tuple, Scale, 6>>>& + instances) { - add_device_operation_instances(instances, device_permute_scale_f32_instances<6>{}); + add_device_operation_instances(instances, device_permute_scale_f16_instances<6, Scale>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp32_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp32_instances.cpp new file mode 100644 index 000000000..fbdace20a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp32_instances.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Scale = element_wise::Scale; + +void add_device_permute_scale_6d_f32_instances( + std::vector, ck::Tuple, Scale, 6>>>& + instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<6, Scale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_permute_scale_impl.hpp b/profiler/include/profiler/profile_permute_scale_impl.hpp index 5bc7c029f..c69e36142 100644 --- a/profiler/include/profiler/profile_permute_scale_impl.hpp +++ b/profiler/include/profiler/profile_permute_scale_impl.hpp @@ -8,9 +8,9 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp" @@ -21,23 +21,12 @@ #include "ck/library/utility/literals.hpp" namespace ck { -template +template void reference_permute_scale(HostTensorB& b_tensor, const HostTensorA& a_tensor, - AElementOp a_tensor_op, - BElementOp b_tensor_op, - ScaleElementOp scale_op) + ElementOp tensor_op) { - b_tensor.ForEach([&](auto& self, auto idx) { - auto tmp_val = a_tensor(idx); - b_tensor_op(tmp_val, tmp_val); - scale_op(tmp_val, tmp_val); - a_tensor_op(self(idx), tmp_val); - }); + b_tensor.ForEach([&](auto& self, auto idx) { tensor_op(self(idx), a_tensor(idx)); }); } namespace profiler { @@ -54,9 +43,7 @@ bool profile_permute_scale_impl(int do_verification, bool pass = true; bool instance_found = false; - using ElementOp = ck::tensor_operation::element_wise::PassThrough; - using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; - using Scale = ck::tensor_operation::element_wise::Scale; + using ElementOp = ck::tensor_operation::element_wise::Scale; float scale = 2.f; Tensor a(lengths_vector, input_strides_vector); @@ -80,12 +67,8 @@ bool profile_permute_scale_impl(int do_verification, std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - using DeviceOp = ck::tensor_operation::device::DeviceElementwise, - ck::Tuple, - ElementOp, - UnaryOp, - Scale, - NumDim>; + using DeviceOp = ck::tensor_operation::device:: + DeviceElementwise, ck::Tuple, ElementOp, NumDim>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< @@ -100,7 +83,7 @@ bool profile_permute_scale_impl(int do_verification, if(do_verification) { - reference_permute_scale(host_b, a, ElementOp{}, UnaryOp{}, Scale{scale}); + reference_permute_scale(host_b, a, ElementOp{scale}); } auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; @@ -113,14 +96,8 @@ bool profile_permute_scale_impl(int do_verification, for(auto& op_ptr : op_ptrs) { - auto argument_ptr = op_ptr->MakeArgumentPointer(lengths, - {input_strides}, - {output_strides}, - input, - output, - 
ElementOp{}, - UnaryOp{}, - Scale{scale}); + auto argument_ptr = op_ptr->MakeArgumentPointer( + lengths, {input_strides}, {output_strides}, input, output, ElementOp{scale}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); @@ -141,6 +118,7 @@ bool profile_permute_scale_impl(int do_verification, if(do_log) { LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_b: ", host_b.mData, ",") << std::endl; LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; } } diff --git a/profiler/src/profile_permute_scale.cpp b/profiler/src/profile_permute_scale.cpp index 921b9b9a6..8ebb2289e 100644 --- a/profiler/src/profile_permute_scale.cpp +++ b/profiler/src/profile_permute_scale.cpp @@ -37,6 +37,20 @@ static void print_helper_msg() // clang-format on } +void init_strides(const std::vector& lengths, + const std::vector& dims_order, + std::vector& strides) +{ + + ck::index_t stride = 1; + for(ck::index_t d = lengths.size() - 1; d >= 0; d--) + { + ck::index_t dim = dims_order[d]; + strides[dim] = stride; + stride *= lengths[dim]; + } +} + } // namespace int profile_permute_scale(int argc, char* argv[]) @@ -58,16 +72,21 @@ int profile_permute_scale(int argc, char* argv[]) const int num_dims = dims_argc / 3; std::vector lengths(num_dims); - std::vector input_strides(num_dims); - std::vector output_strides(num_dims); + std::vector input_dims_order(num_dims); + std::vector output_dims_order(num_dims); for(int i = 0; i < num_dims; i++) { - lengths[i] = std::stoi(argv[control_argc + i]); - input_strides[i] = std::stoi(argv[control_argc + num_dims + i]); - output_strides[i] = std::stoi(argv[control_argc + 2 * num_dims + i]); + lengths[i] = std::stoi(argv[control_argc + i]); + input_dims_order[i] = std::stoi(argv[control_argc + num_dims + i]); + output_dims_order[i] = std::stoi(argv[control_argc + 2 * num_dims + i]); } + std::vector input_strides(num_dims); + std::vector output_strides(num_dims); + init_strides(lengths, input_dims_order, input_strides); + init_strides(lengths, output_dims_order, output_strides); + using F32 = float; using F16 = ck::half_t; diff --git a/script/profile_permute_scale.sh b/script/profile_permute_scale.sh new file mode 100755 index 000000000..945d10f47 --- /dev/null +++ b/script/profile_permute_scale.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" +echo $DRIVER +OP=$1 +DATATYPE=$2 +VERIFY=$3 +INIT=$4 +LOG=$5 +TIME=$6 + + +# 1D +######## op datatype verify init log time dims in_strides_order out_strides_order + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 67108864 0 0 + +# # 2D +# ######## op datatype verify init log time dims in_strides_order out_strides_order + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8192 8192 0 1 1 0 + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8192 8192 1 0 0 1 + +# 3D +######## op datatype verify init log time dims in_strides_order out_strides_order + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 1024 8192 0 1 2 2 1 0 + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 1024 8192 2 1 0 0 1 2 + +# 4D +######## op datatype verify init log time dims in_strides_order out_strides_order + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 512 8192 0 1 2 3 3 2 1 0 + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 512 8192 3 2 1 0 0 1 2 3 + +# 5D +######## op datatype verify init log time dims in_strides_order out_strides_order + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 2 256 8192 0 1 2 3 4 4 3 2 1 0 + $DRIVER $OP $DATATYPE 
$VERIFY $INIT $LOG $TIME 8 2 2 256 8192 4 3 2 1 0 0 1 2 3 4 + + # 6D +######## op datatype verify init log time dims in_strides_order out_strides_order + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 2 2 128 8192 0 1 2 3 4 5 5 4 3 2 1 0 + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 2 2 128 8192 5 4 3 2 1 0 0 1 2 3 4 5 + diff --git a/test/permute_scale/test_permute_scale.cpp b/test/permute_scale/test_permute_scale.cpp index 780f6d6ed..e40d4861c 100644 --- a/test/permute_scale/test_permute_scale.cpp +++ b/test/permute_scale/test_permute_scale.cpp @@ -52,40 +52,40 @@ TYPED_TEST_SUITE(TestPermute, KernelTypes); TYPED_TEST(TestPermute, Test1D) { constexpr ck::index_t NumDims = 1; - this->template Run({8}, {1}, {2}); - this->template Run({8}, {2}, {1}); + this->template Run({16}, {1}, {1}); + this->template Run({16}, {1}, {2}); this->template Run({1}, {1}, {1}); } TYPED_TEST(TestPermute, Test2D) { constexpr ck::index_t NumDims = 2; - this->template Run({8, 4}, {4, 1}, {1, 8}); - this->template Run({8, 4}, {1, 8}, {4, 1}); + this->template Run({8, 16}, {16, 1}, {1, 8}); + this->template Run({8, 16}, {1, 8}, {16, 1}); this->template Run({1, 1}, {1, 1}, {1, 1}); } TYPED_TEST(TestPermute, Test3D) { constexpr ck::index_t NumDims = 3; - this->template Run({2, 4, 4}, {16, 4, 1}, {1, 2, 8}); - this->template Run({2, 4, 4}, {1, 2, 8}, {16, 4, 1}); + this->template Run({8, 2, 8}, {16, 8, 1}, {1, 8, 16}); + this->template Run({8, 2, 8}, {1, 8, 16}, {16, 8, 1}); this->template Run({1, 1, 1}, {1, 1, 1}, {1, 1, 1}); } TYPED_TEST(TestPermute, Test4D) { constexpr ck::index_t NumDims = 4; - this->template Run({2, 4, 4, 4}, {64, 16, 4, 1}, {1, 2, 8, 32}); - this->template Run({2, 4, 4, 4}, {1, 2, 8, 32}, {64, 16, 4, 1}); + this->template Run({8, 2, 3, 8}, {48, 24, 8, 1}, {1, 8, 16, 48}); + this->template Run({8, 2, 3, 8}, {1, 8, 16, 48}, {48, 24, 8, 1}); this->template Run({1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}); } TYPED_TEST(TestPermute, Test5D) { constexpr ck::index_t NumDims = 5; - this->template Run({2, 4, 4, 4, 4}, {256, 64, 16, 4, 1}, {1, 2, 8, 32, 128}); - this->template Run({2, 4, 4, 4, 4}, {1, 2, 8, 32, 128}, {256, 64, 16, 4, 1}); + this->template Run({8, 2, 3, 4, 8}, {192, 96, 32, 8, 1}, {1, 8, 16, 48, 192}); + this->template Run({8, 2, 3, 4, 8}, {1, 8, 16, 48, 192}, {192, 96, 32, 8, 1}); this->template Run({1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}); } @@ -93,8 +93,8 @@ TYPED_TEST(TestPermute, Test6D) { constexpr ck::index_t NumDims = 6; this->template Run( - {2, 4, 4, 4, 4, 4}, {1024, 256, 64, 16, 4, 1}, {1, 2, 8, 32, 128, 512}); + {8, 2, 3, 4, 5, 8}, {960, 480, 160, 40, 8, 1}, {1, 8, 16, 48, 192, 960}); this->template Run( - {2, 4, 4, 4, 4, 4}, {1, 2, 8, 32, 128, 512}, {1024, 256, 64, 16, 4, 1}); + {8, 2, 3, 4, 5, 8}, {1, 8, 16, 48, 192, 960}, {960, 480, 160, 40, 8, 1}); this->template Run({1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}); } -- GitLab From 2ae16e901f75594022848a05ba9c1b6d0e3e4d6d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 22 Mar 2024 07:58:36 -0700 Subject: [PATCH 11/63] Bump rocm-docs-core from 0.37.0 to 0.37.1 in /docs/sphinx (#1211) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.37.0 to 0.37.1. 
- [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.37.0...v0.37.1) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index ae92cc6c1..76ec2700c 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.37.0 +rocm-docs-core==0.37.1 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 43853dd3f..ab2415f0c 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -111,7 +111,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.37.0 +rocm-docs-core==0.37.1 # via -r requirements.in six==1.16.0 # via -- GitLab From cc1f733d0eaab81c2185888668479fb30b200bdb Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 22 Mar 2024 15:39:11 -0700 Subject: [PATCH 12/63] allow the CI to pass even if can't connect to db (#1214) --- Jenkinsfile | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index ec3cbd0e2..654c7274f 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -619,6 +619,8 @@ def process_results(Map conf=[:]){ dir("script"){ if (params.RUN_FULL_QA){ // unstash perf files to master + unstash "ckprofiler_0.2.0_amd64.deb" + sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" unstash "perf_gemm.log" unstash "perf_resnet50_N256.log" unstash "perf_resnet50_N4.log" @@ -632,8 +634,6 @@ def process_results(Map conf=[:]){ unstash "perf_onnx_gemm.log" unstash "perf_mixed_gemm.log" sh "./process_qa_data.sh" - unstash "ckprofiler_0.2.0_amd64.deb" - sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" } else{ // unstash perf files to master @@ -645,10 +645,13 @@ def process_results(Map conf=[:]){ } } catch(e){ - echo "throwing error exception while processing performance test results" + echo "Throwing error exception while processing performance test results" echo 'Exception occurred: ' + e.toString() throw e } + finally{ + echo "Finished processing performance test results" + } } } } -- GitLab From 5f2c89e8b43d670e3405a4f17ff475d25960f9b3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 27 Mar 2024 10:23:54 -0700 Subject: [PATCH 13/63] Bump rocm-docs-core from 0.37.1 to 0.38.0 in /docs/sphinx (#1218) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.37.1 to 0.38.0. 
- [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.37.1...v0.38.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 76ec2700c..2b28fcdd3 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.37.1 +rocm-docs-core==0.38.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index ab2415f0c..335d6e5e0 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -111,7 +111,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.37.1 +rocm-docs-core==0.38.0 # via -r requirements.in six==1.16.0 # via -- GitLab From 303d4594f4c086e15f2cf5fc7fcb00cae6a49c15 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Tue, 2 Apr 2024 11:02:52 -0500 Subject: [PATCH 14/63] improved zeroing (#1221) --- example/15_grouped_gemm/CMakeLists.txt | 4 +- .../grouped_gemm_xdl_fixed_nk_fp16.cpp | 10 +- ...=> grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp} | 4 +- .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 219 +++++++++----- ...se_gemm_multiple_d_xdl_splitk_cshuffle.hpp | 277 ++++++++++++------ 5 files changed, 345 insertions(+), 169 deletions(-) rename example/15_grouped_gemm/{grouped_gemm_xdl_fixed_nk_fp8.cpp => grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp} (99%) diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 84040fcf5..550dafb06 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -23,8 +23,8 @@ add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16) add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8) -add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp) -add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8) +add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16_fp8 grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp) +add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16_fp8) if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 2c1feafce..1a2bcfb33 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -36,7 +36,7 @@ using BDataType = F16; using AccDataType = F32; using CShuffleDataType = F32; using DsDataType = ck::Tuple<>; -using EDataType = F32; +using EDataType = F16; using ALayout = Row; using BLayout = Col; @@ -55,7 +55,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| 
Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; // clang-format on struct ProblemSize final @@ -298,9 +298,9 @@ int main(int argc, char* argv[]) for(int i = 0; i < problem_size.group_count; i++) { - problem_size.Ms.push_back(256 + 256 * i); - problem_size.Ns.push_back(256); - problem_size.Ks.push_back(128); + problem_size.Ms.push_back(128 + rand() % 128); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(1024); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp similarity index 99% rename from example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp rename to example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp index 9fd63cba7..0a63a2984 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp8.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp @@ -35,7 +35,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = F16; using BDataType = F8; using AccDataType = F32; -using CShuffleDataType = F32; +using CShuffleDataType = F16; using DsDataType = ck::Tuple<>; using EDataType = F16; @@ -56,7 +56,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < 
ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // clang-format on struct ProblemSize final diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index d197c56ab..c98ec6e2a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -23,6 +23,7 @@ namespace device { template (gemm_desc_ptr[group_id].p_a_grid, - gemm_desc_ptr[group_id].p_b_grid, - p_ds_grid_, - gemm_desc_ptr[group_id].p_e_grid, - p_shared, - barrier_count_finished, - a_element_op, - b_element_op, - c_element_op, - M, - N, - K, - StrideA, - StrideB, - StrideDs, - StrideE, - KBatch, - block_2_etile_map); + if constexpr(Zeroing) + { + auto barrier_count_finished = + barrier_count + group_id * barrier_size_grp + id_local % mn_blocks; + GridwiseGemm::template RunWithZeroing(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + barrier_count_finished, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } + else + { + + GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + nullptr, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } id_off += grid_size_grp; id_local += grid_size_grp; @@ -193,8 +224,11 @@ template + PipelineVersion PipelineVer = PipelineVersion::v1, + LoopScheduler LoopSched = make_default_loop_scheduler(), + typename ComputeType = ADataType, + typename ALDSType = ComputeType, + typename BLDSType = ComputeType> struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK{}; static constexpr auto I2 = Number<2>{}; + using AComputeType = ComputeType; + using BComputeType = ComputeType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, - ComputeType, + AComputeType, + BComputeType, AccDataType, CShuffleDataType, DsDataType, @@ -258,7 +296,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK; + LoopSched, + PipelineVer, + ALDSType, + BLDSType>; template struct OffsettedBlockToCTileMapMLoops @@ -613,45 +654,85 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK, - GemmSpec, - ALayout, - BLayout, - DsLayout, - ELayout, - DsDataType, - Block2ETileMap, - GroupedGemmBlock2ETileMap, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - e_global_memory_operation_, - has_main_k_block_loop_>; - - return 
launch_and_time_kernel( - stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), - reinterpret_cast(arg.p_workspace_), - arg.barrier_size_grp_, - arg.gemm_desc_kernel_arg_.size(), - arg.grid_size_grp_, - arg.k_batch_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); + if(arg.k_batch_ == 1) + { + const auto kernel = + kernel_grouped_gemm_xdl_fixed_nk, + GemmSpec, + false, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + nullptr, + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + else + { + const auto kernel = + kernel_grouped_gemm_xdl_fixed_nk, + GemmSpec, + true, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + reinterpret_cast(arg.p_workspace_), + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } }; constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd; constexpr auto Set = InMemoryDataOperationEnum::Set; - // For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced - // in IsSupportedArgument function + // For bf16 datatype only kbatch = 1 scenario is supported. 
This condition is + // enforced in IsSupportedArgument function if constexpr(std::is_same::value) { if(has_main_k_block_loop) @@ -719,12 +800,12 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK + PipelineVersion PipelineVer, + typename ALDSType, + typename BLDSType> struct GridwiseGemmMultipleD_xdl_splitk_cshuffle { static constexpr index_t NumDTensor = DsDataType::Size(); @@ -186,8 +189,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ComputeType), + return math::max(a_block_space_size_aligned * sizeof(ALDSType) + + b_block_space_size_aligned * sizeof(BLDSType), c_block_size * sizeof(CShuffleDataType)); } @@ -455,6 +458,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle InMemoryDataOperationEnum EGlobalMemoryDataOperation, index_t NumDTensor_, typename DsDataType_, + bool Zeroing, typename AGridDesc_KBatch_AK0_M_AK1, typename BGridDesc_KBatch_BK0_N_BK1, typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, @@ -530,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ADataType, - ComputeType, + ALDSType, decltype(a_grid_desc_kbatch_ak0_m_ak1), decltype(a_block_desc_kbatch_ak0_m_ak1), ABlockTransferSrcAccessOrder, @@ -561,7 +565,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BDataType, - ComputeType, + BLDSType, decltype(b_grid_desc_kbatch_bk0_n_bk1), decltype(b_block_desc_kbatch_bk0_n_bk1), BBlockTransferSrcAccessOrder, @@ -597,12 +601,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // sanity check constexpr index_t KPack = math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeType, - ComputeType, + ALDSType, + BLDSType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -611,62 +615,65 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle MXdlPerWave, NXdlPerWave, KPack, - LoopSched>(); + LoopSched, + AComputeType, + BComputeType>(); -#if 1 - if(block_work_idx[I0] == 0) + if constexpr(Zeroing) { - const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; - const index_t numNThreads = NPerBlock / nThreadSize; - const index_t numMThreads = BlockSize / numNThreads; - const index_t mThreadSize = MPerBlock / numMThreads; - - const index_t m_tid = get_thread_local_1d_id() / numNThreads; - const index_t n_tid = get_thread_local_1d_id() % numNThreads; - - auto c_thread_desc_mblock_mperblock_nblock_nperblock = - make_naive_tensor_descriptor_packed( - make_tuple(I1, Number{}, I1, Number{})); - - StaticBuffer - e_thread_zero_buf; - - auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< - EDataType, - EDataType, - decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), - decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), - ck::tensor_operation::element_wise::PassThrough, - Sequence<1, mThreadSize, 1, nThreadSize>, - Sequence<0, 1, 2, 3>, - 3, - CDEShuffleBlockTransferScalarPerVector_NPerBlock, - InMemoryDataOperationEnum::Set, - 1, - true>{e_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I1], - m_tid * mThreadSize, 
- block_work_idx[I2], - n_tid * nThreadSize), - ck::tensor_operation::element_wise::PassThrough{}}; - - c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, - make_tuple(I0, I0, I0, I0), - e_thread_zero_buf, - e_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_buf); - - __syncthreads(); - - if(threadIdx.x == 0) + if(block_work_idx[I0] == 0) { - atomicAdd(barrier_count_finished, 1); + const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; + const index_t numNThreads = NPerBlock / nThreadSize; + const index_t numMThreads = BlockSize / numNThreads; + const index_t mThreadSize = MPerBlock / numMThreads; + + const index_t m_tid = get_thread_local_1d_id() / numNThreads; + const index_t n_tid = get_thread_local_1d_id() % numNThreads; + + auto c_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + StaticBuffer + e_thread_zero_buf; + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + EDataType, + EDataType, + decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, mThreadSize, 1, nThreadSize>, + Sequence<0, 1, 2, 3>, + 3, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], + m_tid * mThreadSize, + block_work_idx[I2], + n_tid * nThreadSize), + ck::tensor_operation::element_wise::PassThrough{}}; + + c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_zero_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + __builtin_amdgcn_s_barrier(); + + if(threadIdx.x == 0) + { + atomicAdd(barrier_count_finished, 1); + } } } -#endif auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -675,10 +682,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); @@ -711,13 +718,15 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // shuffle C and write out { - if(threadIdx.x == 0) + if constexpr(Zeroing) { - while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + if(threadIdx.x == 0) + { + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + } + __builtin_amdgcn_s_barrier(); } - __syncthreads(); - static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); @@ -951,13 +960,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle } }); - if(threadIdx.x == 0) + if constexpr(Zeroing) { - index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); - - if(k_id_finished_t == KBatch) + if(threadIdx.x == 0) { - *barrier_count_finished = 0; + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); + + if(k_id_finished_t == KBatch) + { + *barrier_count_finished = 0; + } } } } @@ -971,24 +983,24 @@ struct 
GridwiseGemmMultipleD_xdl_splitk_cshuffle typename DsLayout, typename ELayout, typename Block2ETileMap> - __device__ static void Run(const void* __restrict__ p_a_grid_, - const void* __restrict__ p_b_grid_, - DsGridPointer p_ds_grid, - void* __restrict__ p_e_grid_, - void* __restrict__ p_shared, - uint32_t* barrier_count_finished, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CDEElementwiseOperation& cde_element_op, - const index_t M, - const index_t N, - const index_t K, - const index_t StrideA, - const index_t StrideB, - const std::array StrideDs, - const index_t StrideE, - const index_t KBatch, - const Block2ETileMap& block_2_etile_map) + __device__ static void RunWithZeroing(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) { const auto p_a_grid = reinterpret_cast(p_a_grid_); const auto p_b_grid = reinterpret_cast(p_b_grid_); @@ -1035,7 +1047,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle if(kbatch_id == KBatch - 1) { - Run( + Run( p_a_grid, p_b_grid, p_ds_grid, @@ -1054,7 +1066,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle } else { - Run>( + Run, true>( p_a_grid, p_b_grid, p_ds_grid, @@ -1072,6 +1084,89 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle block_2_etile_map); } } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t*, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = 
+ MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + nullptr, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } }; } // namespace ck -- GitLab From ae57e5938e7fdfd049055a855910f66054e04163 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 2 Apr 2024 09:42:17 -0700 Subject: [PATCH 15/63] Split the instances by architecture. (#1223) * parse examples inside the add_example_executable function * fix the example 64 cmake file * add xdl flag to the gemm_bias_softmax_gemm_permute example * add filtering of tests based on architecture type * enable test_grouped_gemm for gfx9 only * enable test_transpose only for gfx9 * only linnk test_transpose if it gets built * split the gemm instances by architectures * split gemm_bilinear,grouped_conv_bwd_weight instances by targets * split instances by architecture * split grouped_conv instances by architecture * fix clang format * fix the if-else logic in group_conv headers * small fix for grouped convolution instances * fix the grouped conv bwd weight dl instances * fix client examples * only enable client examples 3 and 4 on gfx9 * set the gfx9 macro * make sure the architecture macros are set by cmake * use separate set of xdl/wmma flags for host code * sinmplify the main cmake file * add conv_fwd_bf8 instance declaration --- CMakeLists.txt | 16 + .../02_gemm_add_add_fastgelu/CMakeLists.txt | 34 +- .../03_gemm_layernorm/CMakeLists.txt | 10 +- client_example/04_contraction/CMakeLists.txt | 23 +- .../07_grouped_convnd_fwd/CMakeLists.txt | 10 +- .../08_fused_attention/CMakeLists.txt | 10 +- client_example/09_quantization/CMakeLists.txt | 30 +- .../15_convnd_bwd_data/CMakeLists.txt | 10 +- .../15_gemm_add_multiply/CMakeLists.txt | 7 +- .../17_grouped_gemm_fastgelu/CMakeLists.txt | 6 +- client_example/20_splitk_gemm/CMakeLists.txt | 2 +- .../21_grouped_gemm_bias/CMakeLists.txt | 6 +- client_example/22_grouped_gemm/CMakeLists.txt | 18 +- .../24_grouped_conv_activation/CMakeLists.txt | 2 + client_example/25_wrapper/CMakeLists.txt | 4 +- client_example/CMakeLists.txt | 5 +- example/01_gemm/CMakeLists.txt | 11 +- example/02_gemm_bilinear/CMakeLists.txt | 23 +- example/03_gemm_bias_relu/CMakeLists.txt | 9 +- .../04_gemm_add_add_fastgelu/CMakeLists.txt | 35 +- example/09_convnd_fwd/CMakeLists.txt | 23 +- .../CMakeLists.txt | 34 +- example/14_gemm_quantization/CMakeLists.txt | 13 +- .../CMakeLists.txt | 89 +- example/17_convnd_bwd_data/CMakeLists.txt | 15 +- .../20_grouped_conv_bwd_weight/CMakeLists.txt | 34 +- example/21_gemm_layernorm/CMakeLists.txt | 16 +- example/26_contraction/CMakeLists.txt | 32 +- .../CMakeLists.txt | 5 +- .../CMakeLists.txt | 51 +- example/31_batched_gemm_gemm/CMakeLists.txt | 20 +- .../CMakeLists.txt | 14 +- example/35_splitK_gemm/CMakeLists.txt | 43 +- .../CMakeLists.txt | 31 +- .../40_conv2d_fwd_quantization/CMakeLists.txt | 39 +- .../41_grouped_conv_conv_fwd/CMakeLists.txt | 20 +- .../CMakeLists.txt | 9 +- ...=> gemm_bias_softmax_gemm_permute_xdl.cpp} | 0 example/52_im2col_col2im/CMakeLists.txt | 18 +- example/60_gemm_multi_ABD/CMakeLists.txt | 9 +- .../61_contraction_multi_ABD/CMakeLists.txt | 9 +- example/62_convnd_activ/CMakeLists.txt | 19 +- example/64_fpAintB_gemm/CMakeLists.txt | 8 +- 
example/CMakeLists.txt | 36 + include/ck/ck.hpp | 14 +- .../tensor_operation_instance/gpu/gemm.hpp | 581 ++------- .../gpu/gemm_bilinear.hpp | 8 +- .../tensor_operation_instance/gpu/gemm_dl.inc | 167 +++ .../gpu/gemm_wmma.inc | 34 + .../gpu/gemm_xdl.inc | 238 ++++ .../gpu/grouped_convolution_backward_data.hpp | 661 +++------- ...grouped_convolution_backward_data_wmma.inc | 243 ++++ .../grouped_convolution_backward_data_xdl.inc | 216 ++++ .../grouped_convolution_backward_weight.hpp | 849 ++++--------- ...grouped_convolution_backward_weight_dl.inc | 243 ++++ ...ouped_convolution_backward_weight_wmma.inc | 114 ++ ...rouped_convolution_backward_weight_xdl.inc | 228 ++++ .../gpu/grouped_convolution_forward.hpp | 1121 +++-------------- .../gpu/grouped_convolution_forward_dl.inc | 73 ++ .../gpu/grouped_convolution_forward_wmma.inc | 480 +++++++ .../gpu/grouped_convolution_forward_xdl.inc | 357 ++++++ .../gpu/CMakeLists.txt | 35 + .../gpu/batched_gemm/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../batched_gemm_bias_permute/CMakeLists.txt | 1 + .../gpu/batched_gemm_gemm/CMakeLists.txt | 1 + .../gpu/batched_gemm_reduce/CMakeLists.txt | 1 + .../batched_gemm_softmax_gemm/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../gpu/contraction_bilinear/CMakeLists.txt | 1 + .../gpu/contraction_scale/CMakeLists.txt | 1 + .../gpu/conv1d_bwd_data/CMakeLists.txt | 1 + .../gpu/conv2d_bwd_data/CMakeLists.txt | 1 + .../gpu/conv2d_fwd/CMakeLists.txt | 1 + .../gpu/conv2d_fwd_bias_relu/CMakeLists.txt | 1 + .../conv2d_fwd_bias_relu_add/CMakeLists.txt | 1 + .../gpu/conv3d_bwd_data/CMakeLists.txt | 1 + .../gpu/gemm_add/CMakeLists.txt | 1 + .../gpu/gemm_add_add_fastgelu/CMakeLists.txt | 1 + .../gpu/gemm_add_fastgelu/CMakeLists.txt | 1 + .../gpu/gemm_add_multiply/CMakeLists.txt | 1 + .../gpu/gemm_add_relu/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../gpu/gemm_add_silu/CMakeLists.txt | 1 + .../gpu/gemm_bias_add_reduce/CMakeLists.txt | 1 + .../gpu/gemm_bilinear/CMakeLists.txt | 1 + .../gpu/gemm_fastgelu/CMakeLists.txt | 1 + .../gpu/gemm_multiply_add/CMakeLists.txt | 1 + .../gpu/gemm_reduce/CMakeLists.txt | 1 + .../gpu/gemm_splitk/CMakeLists.txt | 1 + .../gpu/gemm_streamk/CMakeLists.txt | 1 + .../grouped_conv1d_bwd_weight/CMakeLists.txt | 1 + .../gpu/grouped_conv1d_fwd/CMakeLists.txt | 1 + .../grouped_conv2d_bwd_data/CMakeLists.txt | 1 + .../grouped_conv2d_bwd_weight/CMakeLists.txt | 1 + .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 1 + .../grouped_conv3d_bwd_data/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../grouped_conv3d_bwd_weight/CMakeLists.txt | 1 + .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../grouped_conv3d_fwd_scale/CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../CMakeLists.txt | 1 + .../gpu/grouped_gemm/CMakeLists.txt | 1 + .../gpu/grouped_gemm_bias/CMakeLists.txt | 1 + .../gpu/grouped_gemm_fastgelu/CMakeLists.txt | 1 + .../gpu/grouped_gemm_fixed_nk/CMakeLists.txt | 1 + .../gpu/quantization/CMakeLists.txt | 1 + profiler/src/CMakeLists.txt | 167 +-- test/CMakeLists.txt | 25 +- test/batched_gemm/CMakeLists.txt | 11 +- ...hed_gemm.cpp => test_batched_gemm_xdl.cpp} | 0 test/batched_gemm_gemm/CMakeLists.txt | 19 +- ...pp => test_batched_gemm_gemm_fp16_xdl.cpp} | 0 test/batched_gemm_reduce/CMakeLists.txt | 13 +- ...6.cpp => batched_gemm_reduce_fp16_xdl.cpp} | 0 test/batched_gemm_softmax_gemm/CMakeLists.txt | 19 +- ...st_batched_gemm_softmax_gemm_fp16_xdl.cpp} | 0 .../CMakeLists.txt | 50 +- ...mm_bias_softmax_gemm_permute_bf16_xdl.cpp} | 0 
...mm_bias_softmax_gemm_permute_fp16_xdl.cpp} | 0 ...ed_gemm_softmax_gemm_permute_bf16_xdl.cpp} | 0 ...ed_gemm_softmax_gemm_permute_fp16_xdl.cpp} | 0 test/contraction/CMakeLists.txt | 21 +- ...cpp => test_contraction_interface_xdl.cpp} | 0 ...ntraction.cpp => test_contraction_xdl.cpp} | 0 test/convnd_bwd_data/CMakeLists.txt | 11 +- ...d_bwd_data.cpp => convnd_bwd_data_xdl.cpp} | 0 test/convnd_fwd/CMakeLists.txt | 11 +- .../{convnd_fwd.cpp => convnd_fwd_xdl.cpp} | 0 test/gemm_add/CMakeLists.txt | 24 +- ...elu.cpp => test_gemm_add_fastgelu_xdl.cpp} | 2 +- ...dd_relu.cpp => test_gemm_add_relu_xdl.cpp} | 2 +- ...dd_silu.cpp => test_gemm_add_silu_xdl.cpp} | 2 +- ...est_gemm_add.hpp => test_gemm_add_xdl.hpp} | 0 test/gemm_layernorm/CMakeLists.txt | 19 +- ..._gemm_add_relu_add_layernorm_fp16_xdl.cpp} | 0 test/gemm_reduce/CMakeLists.txt | 2 +- ...duce_fp16.cpp => gemm_reduce_fp16_xdl.cpp} | 0 test/gemm_split_k/CMakeLists.txt | 9 +- ...mm_splitk.cpp => test_gemm_splitk_xdl.cpp} | 0 test/grouped_convnd_bwd_data/CMakeLists.txt | 31 +- ...test_grouped_convnd_bwd_data_xdl_wmma.cpp} | 0 test/grouped_convnd_bwd_weight/CMakeLists.txt | 26 +- ...st_grouped_convnd_bwd_weight_xdl_wmma.cpp} | 0 test/grouped_convnd_fwd/CMakeLists.txt | 16 +- ...ti_d_interface_compatibility_xdl_wmma.cpp} | 0 ...p => test_grouped_convnd_fwd_xdl_wmma.cpp} | 0 test/grouped_gemm/CMakeLists.txt | 27 +- ...pp => test_grouped_gemm_interface_xdl.cpp} | 0 ...k.cpp => test_grouped_gemm_splitk_xdl.cpp} | 0 test/normalization_bwd_data/CMakeLists.txt | 13 +- .../CMakeLists.txt | 13 +- test/permute_scale/CMakeLists.txt | 6 +- test/transpose/CMakeLists.txt | 13 +- ...t_transpose.cpp => test_transpose_xdl.cpp} | 0 test/wrapper/CMakeLists.txt | 6 +- ...per_gemm.cpp => test_wrapper_gemm_xdl.cpp} | 0 160 files changed, 3752 insertions(+), 3374 deletions(-) rename example/47_gemm_bias_softmax_gemm_permute/{gemm_bias_softmax_gemm_permute.cpp => gemm_bias_softmax_gemm_permute_xdl.cpp} (100%) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_dl.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc rename test/batched_gemm/{test_batched_gemm.cpp => test_batched_gemm_xdl.cpp} (100%) rename test/batched_gemm_gemm/{test_batched_gemm_gemm_fp16.cpp => test_batched_gemm_gemm_fp16_xdl.cpp} (100%) rename test/batched_gemm_reduce/{batched_gemm_reduce_fp16.cpp => batched_gemm_reduce_fp16_xdl.cpp} (100%) rename test/batched_gemm_softmax_gemm/{test_batched_gemm_softmax_gemm_fp16.cpp => 
test_batched_gemm_softmax_gemm_fp16_xdl.cpp} (100%) rename test/batched_gemm_softmax_gemm_permute/{test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp => test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp} (100%) rename test/batched_gemm_softmax_gemm_permute/{test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp => test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp} (100%) rename test/batched_gemm_softmax_gemm_permute/{test_batched_gemm_softmax_gemm_permute_bf16.cpp => test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp} (100%) rename test/batched_gemm_softmax_gemm_permute/{test_batched_gemm_softmax_gemm_permute_fp16.cpp => test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp} (100%) rename test/contraction/{test_contraction_interface.cpp => test_contraction_interface_xdl.cpp} (100%) rename test/contraction/{test_contraction.cpp => test_contraction_xdl.cpp} (100%) rename test/convnd_bwd_data/{convnd_bwd_data.cpp => convnd_bwd_data_xdl.cpp} (100%) rename test/convnd_fwd/{convnd_fwd.cpp => convnd_fwd_xdl.cpp} (100%) rename test/gemm_add/{test_gemm_add_fastgelu.cpp => test_gemm_add_fastgelu_xdl.cpp} (98%) rename test/gemm_add/{test_gemm_add_relu.cpp => test_gemm_add_relu_xdl.cpp} (98%) rename test/gemm_add/{test_gemm_add_silu.cpp => test_gemm_add_silu_xdl.cpp} (98%) rename test/gemm_add/{test_gemm_add.hpp => test_gemm_add_xdl.hpp} (100%) rename test/gemm_layernorm/{test_gemm_add_relu_add_layernorm_fp16.cpp => test_gemm_add_relu_add_layernorm_fp16_xdl.cpp} (100%) rename test/gemm_reduce/{gemm_reduce_fp16.cpp => gemm_reduce_fp16_xdl.cpp} (100%) rename test/gemm_split_k/{test_gemm_splitk.cpp => test_gemm_splitk_xdl.cpp} (100%) rename test/grouped_convnd_bwd_data/{test_grouped_convnd_bwd_data.cpp => test_grouped_convnd_bwd_data_xdl_wmma.cpp} (100%) rename test/grouped_convnd_bwd_weight/{test_grouped_convnd_bwd_weight.cpp => test_grouped_convnd_bwd_weight_xdl_wmma.cpp} (100%) rename test/grouped_convnd_fwd/{test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp => test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp} (100%) rename test/grouped_convnd_fwd/{test_grouped_convnd_fwd.cpp => test_grouped_convnd_fwd_xdl_wmma.cpp} (100%) rename test/grouped_gemm/{test_grouped_gemm_interface.cpp => test_grouped_gemm_interface_xdl.cpp} (100%) rename test/grouped_gemm/{test_grouped_gemm_splitk.cpp => test_grouped_gemm_splitk_xdl.cpp} (100%) rename test/transpose/{test_transpose.cpp => test_transpose_xdl.cpp} (100%) rename test/wrapper/{test_wrapper_gemm.cpp => test_wrapper_gemm_xdl.cpp} (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index bdeba33ea..3c77f520a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -143,6 +143,22 @@ if(GPU_TARGETS) else() message("Building CK for the following targets: ${AMDGPU_TARGETS}") endif() + +if (GPU_TARGETS) + if (GPU_TARGETS MATCHES "gfx9") + add_definitions(-DCK_USE_XDL) + set(CK_USE_XDL "ON") + endif() + if (GPU_TARGETS MATCHES "gfx11") + add_definitions(-DCK_USE_WMMA) + set(CK_USE_WMMA "ON") + endif() +else() + add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) + set(CK_USE_XDL "ON") + set(CK_USE_WMMA "ON") +endif() + find_package(hip) # No assumption that HIP kernels are launched with uniform block size for backward compatibility # SWDEV-413293 and https://reviews.llvm.org/D155213 diff --git a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt index 772b69995..4ba86026b 100644 --- a/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt +++ 
b/client_example/02_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,27 +1,29 @@ -add_custom_target(client_gemm_fastgelu_examples) +if(GPU_TARGETS MATCHES "gfx9") + add_custom_target(client_gemm_fastgelu_examples) -add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp) -target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_add_fastgelu gemm_add_add_fastgelu.cpp) + target_link_libraries(client_gemm_add_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp) -target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_fastgelu gemm_add_fastgelu.cpp) + target_link_libraries(client_gemm_add_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_fastgelu gemm_fastgelu.cpp) -target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_fastgelu gemm_fastgelu.cpp) + target_link_libraries(client_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) -add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu + add_dependencies(client_gemm_fastgelu_examples client_gemm_add_add_fastgelu client_gemm_add_fastgelu client_gemm_fastgelu) -add_custom_target(client_gemm_fastgelu_generic_examples) + add_custom_target(client_gemm_fastgelu_generic_examples) -add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp) -target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_add_fastgelu_generic gemm_add_add_fastgelu_generic.cpp) + target_link_libraries(client_gemm_add_add_fastgelu_generic composable_kernel::device_gemm_operations) -add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) -target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_add_fastgelu_generic gemm_add_fastgelu_generic.cpp) + target_link_libraries(client_gemm_add_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) -target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_gemm_fastgelu_generic gemm_fastgelu_generic.cpp) + target_link_libraries(client_gemm_fastgelu_generic PRIVATE composable_kernel::device_gemm_operations) -add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic + add_dependencies(client_gemm_fastgelu_generic_examples client_gemm_add_add_fastgelu_generic client_gemm_add_fastgelu_generic client_gemm_fastgelu_generic) +endif() diff --git a/client_example/03_gemm_layernorm/CMakeLists.txt b/client_example/03_gemm_layernorm/CMakeLists.txt index 94b4576f6..8fedc8463 100644 --- a/client_example/03_gemm_layernorm/CMakeLists.txt +++ b/client_example/03_gemm_layernorm/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp) -target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp) + 
target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) -add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp) -target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) + add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp) + target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations) +endif() diff --git a/client_example/04_contraction/CMakeLists.txt b/client_example/04_contraction/CMakeLists.txt index cd4a95124..13c037584 100644 --- a/client_example/04_contraction/CMakeLists.txt +++ b/client_example/04_contraction/CMakeLists.txt @@ -1,15 +1,16 @@ -add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) -target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp) + target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp) -target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(client_contraction_bilinear_fp32 contraction_bilinear_fp32.cpp) + target_link_libraries(client_contraction_bilinear_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp) -target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(client_contraction_scale_fp64 contraction_scale_fp64.cpp) + target_link_libraries(client_contraction_scale_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) -add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp) -target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) - -add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp) -target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(client_contraction_bilinear_fp64 contraction_bilinear_fp64.cpp) + target_link_libraries(client_contraction_bilinear_fp64 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) + add_executable(contraction_g1m2n3k1_add_xdl_fp16 
contraction_g1m2n3k1_add_xdl_fp16.cpp) + target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/07_grouped_convnd_fwd/CMakeLists.txt b/client_example/07_grouped_convnd_fwd/CMakeLists.txt index 40f1bba06..710eca9f4 100644 --- a/client_example/07_grouped_convnd_fwd/CMakeLists.txt +++ b/client_example/07_grouped_convnd_fwd/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp) -target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp) + target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations) -add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) -target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) + add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp) + target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations) +endif() \ No newline at end of file diff --git a/client_example/08_fused_attention/CMakeLists.txt b/client_example/08_fused_attention/CMakeLists.txt index 9472be07b..4bcde367d 100644 --- a/client_example/08_fused_attention/CMakeLists.txt +++ b/client_example/08_fused_attention/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_fused_attention fused_attention.cpp) -target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_fused_attention fused_attention.cpp) + target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_fused_attention_bias fused_attention_bias.cpp) -target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_fused_attention_bias fused_attention_bias.cpp) + target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/09_quantization/CMakeLists.txt b/client_example/09_quantization/CMakeLists.txt index 65ad642ce..d2d3a427e 100644 --- a/client_example/09_quantization/CMakeLists.txt +++ b/client_example/09_quantization/CMakeLists.txt @@ -1,22 +1,22 @@ -if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) -add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)) + add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization 
conv2d_fwd_bias_relu_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_bias_relu_perchannel_quantization conv2d_fwd_bias_relu_perchannel_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_relu_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_bias_tanh_perlayer_quantization conv2d_fwd_bias_tanh_perlayer_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_tanh_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_bias_relu_perlayer_quantization conv2d_fwd_bias_relu_perlayer_quantization.cpp) + target_link_libraries(client_conv2d_fwd_bias_relu_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp) -target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_perchannel_quantization conv2d_fwd_perchannel_quantization.cpp) + target_link_libraries(client_conv2d_fwd_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) -target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_conv2d_fwd_perlayer_quantization conv2d_fwd_perlayer_quantization.cpp) + target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) -add_executable(client_gemm_quantization gemm_quantization.cpp) -target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations) + add_executable(client_gemm_quantization gemm_quantization.cpp) + target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations 
composable_kernel::device_gemm_operations) endif() diff --git a/client_example/15_convnd_bwd_data/CMakeLists.txt b/client_example/15_convnd_bwd_data/CMakeLists.txt index f35cd82d7..8fc62bc2b 100644 --- a/client_example/15_convnd_bwd_data/CMakeLists.txt +++ b/client_example/15_convnd_bwd_data/CMakeLists.txt @@ -1,5 +1,7 @@ -add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp) -add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp) + add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp) -target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations) -target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations) + target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_conv_operations) + target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_conv_operations) +endif() diff --git a/client_example/15_gemm_add_multiply/CMakeLists.txt b/client_example/15_gemm_add_multiply/CMakeLists.txt index 4b4d76200..a683f7857 100644 --- a/client_example/15_gemm_add_multiply/CMakeLists.txt +++ b/client_example/15_gemm_add_multiply/CMakeLists.txt @@ -1,3 +1,4 @@ - -add_executable(client_gemm_add_multiply gemm_add_multiply.cpp) -target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations) \ No newline at end of file +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_gemm_add_multiply gemm_add_multiply.cpp) + target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt index fd315afbd..39bef7181 100644 --- a/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt +++ b/client_example/17_grouped_gemm_fastgelu/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp) -target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) \ No newline at end of file +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp) + target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/20_splitk_gemm/CMakeLists.txt b/client_example/20_splitk_gemm/CMakeLists.txt index a3dc85376..05fcaa810 100644 --- a/client_example/20_splitk_gemm/CMakeLists.txt +++ b/client_example/20_splitk_gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) +if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)) add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp) target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations) endif() diff --git a/client_example/21_grouped_gemm_bias/CMakeLists.txt b/client_example/21_grouped_gemm_bias/CMakeLists.txt index 92e31495c..a09921e50 100644 --- a/client_example/21_grouped_gemm_bias/CMakeLists.txt +++ b/client_example/21_grouped_gemm_bias/CMakeLists.txt @@ -1,2 +1,4 @@ -add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations) 
+if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/22_grouped_gemm/CMakeLists.txt b/client_example/22_grouped_gemm/CMakeLists.txt index 0c3cb956f..1e1c39681 100644 --- a/client_example/22_grouped_gemm/CMakeLists.txt +++ b/client_example/22_grouped_gemm/CMakeLists.txt @@ -1,11 +1,13 @@ -add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations) +if(GPU_TARGETS MATCHES "gfx9") + add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_grouped_gemm_fixed_nk_fp8 grouped_gemm_fixed_nk_fp8.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations) -add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp) -target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations) + add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp) + target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations) +endif() diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index 074dcd9b9..e79dee9f7 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -1,3 +1,4 @@ +if(GPU_TARGETS MATCHES "gfx9") # Fwd scaleadd scaleadd relu add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp) @@ -46,3 +47,4 @@ target_link_libraries(client_grouped_convnd_fwd_scale_fp16 PRIVATE composable_ke add_executable(client_grouped_convnd_bwd_data_scale_fp16 grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp) target_link_libraries(client_grouped_convnd_bwd_data_scale_fp16 PRIVATE composable_kernel::device_conv_operations) +endif() diff --git a/client_example/25_wrapper/CMakeLists.txt b/client_example/25_wrapper/CMakeLists.txt index fdfc1d8d2..b1e9d20bf 100644 --- a/client_example/25_wrapper/CMakeLists.txt +++ b/client_example/25_wrapper/CMakeLists.txt @@ -2,9 +2,7 @@ add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrap target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations) add_executable(client_wrapper_img2col wrapper_img2col.cpp) target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations) -if(GPU_TARGETS MATCHES 
"gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR - GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR - GPU_TARGETS MATCHES "gfx942") +if(GPU_TARGETS MATCHES "gfx9") add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp) target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations) add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp) diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 753f5e5ae..3aa9efa31 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -48,7 +48,10 @@ else() endif() endif() -find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_contraction_operations device_reduction_operations) +find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations) +if(GPU_TARGETS MATCHES "gfx9") + find_package(composable_kernel COMPONENTS device_contraction_operations) +endif() find_package(hip REQUIRED PATHS /opt/rocm) message(STATUS "Build with HIP ${hip_VERSION}") diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 2fa8e7746..39e3f2a2b 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -27,11 +27,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) -if(GPU_TARGETS MATCHES "gfx11") - add_custom_target(example_gemm_wmma) - add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) - add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) -endif() add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16) @@ -47,8 +42,7 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_gemm_xdl example_gemm_xdl_int4) endif(USE_BITINT_EXTENSION_INT4) -# FIXME: re-enable this example as test when SWDEV-335738 is fixed -add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) +add_example_executable(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) @@ -75,3 +69,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) +add_custom_target(example_gemm_wmma) +add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) diff --git a/example/02_gemm_bilinear/CMakeLists.txt b/example/02_gemm_bilinear/CMakeLists.txt index d82c42d5a..2c20b96ee 100644 --- a/example/02_gemm_bilinear/CMakeLists.txt +++ b/example/02_gemm_bilinear/CMakeLists.txt @@ -1,20 +1,3 @@ -list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102) -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) - add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp) -endif() -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940") - 
set(target 1) - endif() -endforeach() - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp) +add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp) +add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp) diff --git a/example/03_gemm_bias_relu/CMakeLists.txt b/example/03_gemm_bias_relu/CMakeLists.txt index 2f5cba924..35c54abac 100644 --- a/example/03_gemm_bias_relu/CMakeLists.txt +++ b/example/03_gemm_bias_relu/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_bias_relu_xdl_fp16 gemm_bias_relu_xdl_fp16.cpp) diff --git a/example/04_gemm_add_add_fastgelu/CMakeLists.txt b/example/04_gemm_add_add_fastgelu/CMakeLists.txt index 33ac1e7e7..ab19f819e 100644 --- a/example/04_gemm_add_add_fastgelu/CMakeLists.txt +++ b/example/04_gemm_add_add_fastgelu/CMakeLists.txt @@ -1,29 +1,20 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_gemm_add_add_fastgelu_xdl) - add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) +add_custom_target(example_gemm_add_add_fastgelu_xdl) +add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16) - add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) +add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) - - add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) - add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp) +add_example_dependencies(example_gemm_add_add_fastgelu_xdl 
example_gemm_add_add_fastgelu_xdl_int8) -set(gpu_list "") +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) + add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) +endif(USE_BITINT_EXTENSION_INT4) list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) set(target 0) diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 195f1857e..61e9a43c3 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -1,19 +1,10 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) - add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) - add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) - add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) - add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) - add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) - # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed - add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) - set(target 1) - endif() -endforeach() - +add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) +add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) +add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) +add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) +add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt index 222a3b7c0..ef8bea185 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/CMakeLists.txt @@ -1,25 +1,17 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_fwd_reduce_xdl) +add_custom_target(example_convnd_fwd_reduce_xdl) +add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) - add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8) +add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16) 
+add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) - add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16) +add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) +add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) - add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) - add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp) + add_example_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4) +endif(USE_BITINT_EXTENSION_INT4) diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index 9793e8b8a..8703fa3ed 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,12 +1,3 @@ -# dlops add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) -# xdlops -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) - add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) +add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) diff --git a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt index 5955e1d6c..1e12c16f3 100644 --- a/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt +++ b/example/16_gemm_multi_d_multi_reduces/CMakeLists.txt @@ -1,48 +1,41 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_gemm_reduce_xdl) - add_custom_target(example_gemm_reduce_xdl_max) - add_custom_target(example_gemm_reduce_xdl_mean_meansquare) - add_custom_target(example_gemm_add_add_mean_meansquare_xdl) - - add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) - - add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) - add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) - - add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) - - add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) - 
add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) - - add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) - - add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) - - add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) - - add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) - - add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) - add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) - - add_example_dependencies(example_gemm_reduce_xdl - example_gemm_reduce_xdl_mean_meansquare - example_gemm_reduce_xdl_max - example_gemm_add_add_mean_meansquare_xdl) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp) - add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) - endif() - set(target 1) - endif() -endforeach() +add_custom_target(example_gemm_reduce_xdl) +add_custom_target(example_gemm_reduce_xdl_max) +add_custom_target(example_gemm_reduce_xdl_mean_meansquare) +add_custom_target(example_gemm_add_add_mean_meansquare_xdl) + +add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16) + +add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp) +add_example_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16) + +add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16) + +add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8) + +add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8) + +add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32) + +add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32) + +add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16) + +add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp) +add_example_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16) + +add_example_dependencies(example_gemm_reduce_xdl + example_gemm_reduce_xdl_mean_meansquare + example_gemm_reduce_xdl_max + example_gemm_add_add_mean_meansquare_xdl) + +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_gemm_max_xdl_int4 
gemm_max_xdl_int4.cpp) + add_example_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4) +endif() diff --git a/example/17_convnd_bwd_data/CMakeLists.txt b/example/17_convnd_bwd_data/CMakeLists.txt index 7c6d10d8a..39f9fb8ec 100644 --- a/example/17_convnd_bwd_data/CMakeLists.txt +++ b/example/17_convnd_bwd_data/CMakeLists.txt @@ -1,14 +1,7 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) - endif() - set(target 1) - endif() -endforeach() +add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp) +if(result EQUAL 0) + target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility) +endif() add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp) if(result EQUAL 0) diff --git a/example/20_grouped_conv_bwd_weight/CMakeLists.txt b/example/20_grouped_conv_bwd_weight/CMakeLists.txt index c28fca6fa..497ea19e1 100644 --- a/example/20_grouped_conv_bwd_weight/CMakeLists.txt +++ b/example/20_grouped_conv_bwd_weight/CMakeLists.txt @@ -1,29 +1,15 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_weight) - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) +add_custom_target(example_grouped_conv_bwd_weight) +add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16) - add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) +add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16) - add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) - set(target 1) - endif() +add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16_comp_bf8_fp8) - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_weight) - add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16) - set(target 1) - endif() -endforeach() - -add_custom_target(example_grouped_conv_bwd_weight_dl) +add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_weight 
example_grouped_conv_bwd_weight_wmma_fp16) add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp) -add_example_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16) +add_example_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_dl_fp16) diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt index e231bc619..2eb7052e1 100644 --- a/example/21_gemm_layernorm/CMakeLists.txt +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -1,12 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) - add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) - add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) - add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) - set(target 1) - endif() -endforeach() - +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_welford_fp16 gemm_bias_relu_add_layernorm_xdl_welford_fp16.cpp) +add_example_executable(example_gemm_bias_relu_add_layernorm_xdl_naive_fp16 gemm_bias_relu_add_layernorm_xdl_naive_fp16.cpp) +add_example_executable(example_gemm_layernorm_xdl_naive_fp16 gemm_layernorm_xdl_naive_fp16.cpp) +add_example_executable(example_gemm_xdl_layernorm_naive_single_kernel_fp16 gemm_xdl_layernorm_naive_single_kernel_fp16.cpp) diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt index 1a0489ce9..f3d30cea2 100644 --- a/example/26_contraction/CMakeLists.txt +++ b/example/26_contraction/CMakeLists.txt @@ -4,49 +4,49 @@ add_custom_target(example_contraction_bilinear) # FP32 add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32) add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32) add_example_executable(example_contraction_bilinear_xdl_fp32_compute_bf16 contraction_bilinear_xdl_fp32_compute_bf16.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16) add_example_executable(example_contraction_scale_xdl_fp32_compute_bf16 contraction_scale_xdl_fp32_compute_bf16.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16) add_example_executable(example_contraction_bilinear_xdl_fp32_compute_fp16 contraction_bilinear_xdl_fp32_compute_fp16.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16) 
add_example_executable(example_contraction_scale_xdl_fp32_compute_fp16 contraction_scale_xdl_fp32_compute_fp16.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16) # FP64 add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64) add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64) add_example_executable(example_contraction_bilinear_xdl_fp64_compute_fp32 contraction_bilinear_xdl_fp64_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32) add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contraction_scale_xdl_fp64_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32) # FP16 add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32) add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32) # BF16 add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 contraction_bilinear_xdl_bf16_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32) +add_example_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32) add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32) +add_example_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32) -add_dependencies(example_contraction example_contraction_scale) -add_dependencies(example_contraction example_contraction_bilinear) +add_example_dependencies(example_contraction example_contraction_scale) +add_example_dependencies(example_contraction example_contraction_bilinear) diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index f343cc191..ac54aebdc 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1,5 +1,2 @@ add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) - -if(GPU_TARGETS MATCHES "gfx11") - 
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) -endif() +add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) diff --git a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt index 3a8c2ef52..7acb1a190 100644 --- a/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt +++ b/example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt @@ -1,40 +1,23 @@ -list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list2 gfx1100 gfx1101 gfx1102) +add_custom_target(example_grouped_conv_fwd_multiple_d) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_custom_target(example_grouped_conv_fwd_multiple_d) +add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) - add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) +add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) + add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) +endif() # USE_BITINT_EXTENSION_INT4 - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp) - add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp) - 
add_example_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4) - endif() # USE_BITINT_EXTENSION_INT4 - - set(target 1) - endif() -endforeach() - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) - add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp) +add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp) diff --git a/example/31_batched_gemm_gemm/CMakeLists.txt b/example/31_batched_gemm_gemm/CMakeLists.txt index 93f16c945..8b648a7f7 100644 --- a/example/31_batched_gemm_gemm/CMakeLists.txt +++ b/example/31_batched_gemm_gemm/CMakeLists.txt @@ -1,17 +1,9 @@ -list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) - add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) - add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() +add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp) +add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp) +add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp) diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index c6cca7b58..519f53910 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -1,11 +1,9 @@ -if(GPU_TARGETS MATCHES "gfx11") - add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) - add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) - add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) - add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp) - add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp) - add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp) -endif() +add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 
batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) +add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) +add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) +add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp) +add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp) +add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp) add_custom_target(example_gemm_scale_softmax_gemm) diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index 5277b32f6..9a62d85ac 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -1,32 +1,23 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_splitK_gemm_xdl) +add_custom_target(example_splitK_gemm_xdl) +add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) - add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32) +add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) - add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) +add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8) - add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8) +add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16) - add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16) +add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) - add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16) +add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) +add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) - add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8) - - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp) - add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) - endif() - - set(target 1) - endif() -endforeach() +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_splitK_gemm_xdl_int4 
splitK_gemm_xdl_int4.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4) +endif() diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index 1ae179e95..72e695964 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -1,27 +1,10 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_data) +add_custom_target(example_grouped_conv_bwd_data) - add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) +add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) - add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) +add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) - set(target 1) - endif() -endforeach() - -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_custom_target(example_grouped_conv_bwd_data) - - add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp) - add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) - - set(target 1) - endif() -endforeach() +add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16) diff --git a/example/40_conv2d_fwd_quantization/CMakeLists.txt b/example/40_conv2d_fwd_quantization/CMakeLists.txt index 2d804cafe..991c1e464 100644 --- a/example/40_conv2d_fwd_quantization/CMakeLists.txt +++ b/example/40_conv2d_fwd_quantization/CMakeLists.txt @@ -1,24 +1,17 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) - add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_conv2d_fwd_xdl_perlayer_quantization_int8 conv2d_fwd_xdl_perlayer_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_perchannel_quantization_int8 conv2d_fwd_xdl_perchannel_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8 
conv2d_fwd_xdl_bias_relu_perlayer_quantization_int8.cpp) +add_example_executable(example_conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8 conv2d_fwd_xdl_bias_relu_perchannel_quantization_int8.cpp) - # Conv perlayer quantization - add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) - # Conv perchannel quantization - add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) - # Conv + bias + relu perlayer quantization - add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) - # Conv + bias + relu perchannel quantization - add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) - # Conv + bias + tanh perlayer quantization - add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) - # Conv + bias + tanh perchannel quantization - add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) +# Conv perlayer quantization +add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp) +# Conv perchannel quantization +add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp) +# Conv + bias + relu perlayer quantization +add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp) +# Conv + bias + relu perchannel quantization +add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp) +# Conv + bias + tanh perlayer quantization +add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp) +# Conv + bias + tanh perchannel quantization +add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp) diff --git a/example/41_grouped_conv_conv_fwd/CMakeLists.txt b/example/41_grouped_conv_conv_fwd/CMakeLists.txt index ae251e88d..8ab56b21a 100644 --- a/example/41_grouped_conv_conv_fwd/CMakeLists.txt +++ b/example/41_grouped_conv_conv_fwd/CMakeLists.txt @@ -1,17 +1,9 @@ -list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list2 gfx908 gfx90a) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list1 AND target EQUAL 0) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) - add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) - add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp) - if(USE_BITINT_EXTENSION_INT4) - add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) - endif(USE_BITINT_EXTENSION_INT4) - set(target 1) - endif() -endforeach() +add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp) +add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp) +add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 
grouped_conv_conv_fwd_xdl_bf16.cpp) +if(USE_BITINT_EXTENSION_INT4) + add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp) +endif(USE_BITINT_EXTENSION_INT4) if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1") add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp) diff --git a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt index 14432f6e2..df1956ca6 100644 --- a/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt +++ b/example/47_gemm_bias_softmax_gemm_permute/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute_xdl.cpp) diff --git a/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp b/example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp similarity index 100% rename from example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute.cpp rename to example/47_gemm_bias_softmax_gemm_permute/gemm_bias_softmax_gemm_permute_xdl.cpp diff --git a/example/52_im2col_col2im/CMakeLists.txt b/example/52_im2col_col2im/CMakeLists.txt index 4dc6c8b4e..63ee1d431 100644 --- a/example/52_im2col_col2im/CMakeLists.txt +++ b/example/52_im2col_col2im/CMakeLists.txt @@ -1,15 +1,7 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_im2col_col2im) +add_custom_target(example_im2col_col2im) - add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp) - add_example_dependencies(example_im2col_col2im example_image_to_column_f32) +add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp) +add_example_dependencies(example_im2col_col2im example_image_to_column_f32) - add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp) - add_example_dependencies(example_im2col_col2im example_column_to_image_f32) - - set(target 1) - endif() -endforeach() +add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp) +add_example_dependencies(example_im2col_col2im example_column_to_image_f32) diff --git a/example/60_gemm_multi_ABD/CMakeLists.txt b/example/60_gemm_multi_ABD/CMakeLists.txt index 57bc0b33e..d3974897f 100644 --- a/example/60_gemm_multi_ABD/CMakeLists.txt +++ b/example/60_gemm_multi_ABD/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) diff --git a/example/61_contraction_multi_ABD/CMakeLists.txt b/example/61_contraction_multi_ABD/CMakeLists.txt index 42500b64e..1b8bd4cad 100644 --- a/example/61_contraction_multi_ABD/CMakeLists.txt +++ b/example/61_contraction_multi_ABD/CMakeLists.txt @@ -1,8 +1 @@ -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS 
GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() +add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp) diff --git a/example/62_convnd_activ/CMakeLists.txt b/example/62_convnd_activ/CMakeLists.txt index 6eaddd3ff..5a35f9b60 100644 --- a/example/62_convnd_activ/CMakeLists.txt +++ b/example/62_convnd_activ/CMakeLists.txt @@ -2,16 +2,9 @@ add_subdirectory(binary) add_subdirectory(multi_AB) add_subdirectory(unary) -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_activ_xdl) - # ScaleAdd ScaleAdd Relu - add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) - add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) - add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp) - add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16) - set(target 1) - endif() -endforeach() +add_custom_target(example_convnd_activ_xdl) +# ScaleAdd ScaleAdd Relu +add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) +add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) +add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp) +add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16) diff --git a/example/64_fpAintB_gemm/CMakeLists.txt b/example/64_fpAintB_gemm/CMakeLists.txt index 89cc2d7f6..b3c77b3bd 100644 --- a/example/64_fpAintB_gemm/CMakeLists.txt +++ b/example/64_fpAintB_gemm/CMakeLists.txt @@ -1,5 +1,3 @@ -if(GPU_TARGETS MATCHES "gfx11") - add_custom_target(example_fpAintB_gemm_wmma) - add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) - add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) -endif() +add_custom_target(example_fpAintB_gemm_wmma) +add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) +add_example_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index c19ba93b6..5465adb77 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -5,6 +5,12 @@ include_directories(BEFORE add_custom_target(examples) +function(add_example_dependencies EXAMPLE_NAME FILE_NAME) + if(FILE_NAME) + add_dependencies(EXAMPLE_NAME FILE_NAME) + endif() +endfunction(add_example_dependencies EXAMPLE_NAME) + function(add_example_executable EXAMPLE_NAME FILE_NAME) message("adding example ${EXAMPLE_NAME}") set(result 1) @@ -38,12 +44,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) endif() endforeach() endif() + #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() + #Do not build any XDL examples if gfx9 targets are not on the list + 
foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + message("removing xdl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any WMMA examples if gfx11 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + message("removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(FILE_NAME) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -97,12 +118,27 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) endif() endforeach() endif() + #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() endforeach() + #Do not build any XDL examples if gfx9 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + message("removing xdl example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() + #Do not build any WMMA examples if gfx11 targets are not on the list + foreach(source IN LISTS FILE_NAME) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + message("removing wmma example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(FILE_NAME) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index c93d1d063..0bda8b759 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -45,6 +45,10 @@ #endif // define general macros for various architectures +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) +#define __gfx9__ +#endif #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #define __gfx94__ #endif @@ -62,8 +66,7 @@ // buffer resource #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_BUFFER_RESOURCE_3RD_DWORD -1 -#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) +#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx9__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 #elif defined(__gfx103__) #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 @@ -75,8 +78,7 @@ #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code #define CK_USE_AMD_V_MAC_F32 -#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \ - defined(__gfx94__) // for GPU code +#elif defined(__gfx906__) || defined(__gfx9__) || defined(__gfx103__) // for GPU code #define CK_USE_AMD_V_FMAC_F32 #define CK_USE_AMD_V_DOT2_F32_F16 #define CK_USE_AMD_V_DOT4_I32_I8 @@ -89,7 +91,7 @@ // MFMA instruction #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_MFMA -#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code +#elif defined(__gfx9__) // for GPU code #define CK_USE_AMD_MFMA #endif @@ -120,7 +122,7 @@ // buffer atomic add: floating point #ifndef __HIP_DEVICE_COMPILE__ // for host code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 -#elif defined(__gfx908__) || 
defined(__gfx90a__) || defined(__gfx94__) // for GPU code +#elif defined(__gfx9__) // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 #else // for GPU code #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp index ee9d97709..50c18fc22 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -12,397 +12,20 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -#if defined(CK_ENABLE_FP16) && defined(DL_KERNELS) -void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances( - std::vector>>& - instances); -#endif -#if defined(CK_ENABLE_FP32) && defined(DL_KERNELS) -void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#if defined(CK_ENABLE_INT8) && defined(DL_KERNELS) -void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_INT8 
-void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( - std::vector>>& - instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances( - std::vector>>& - instances); +#ifdef DL_KERNELS +#include "gemm_dl.inc" #endif -#ifdef CK_ENABLE_FP64 -void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( - std::vector>>& - 
instances); +#ifdef CK_USE_WMMA +#include "gemm_wmma.inc" #endif -#ifdef CK_ENABLE_FP8 -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances( - std::vector>>& instances); - -void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances( - std::vector>>& - instances); +#ifdef CK_USE_XDL +#include "gemm_xdl.inc" #endif -void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& - instances); - -void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& - instances); +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { template > op_ptrs; +#ifdef DL_KERNELS if constexpr(is_same_v && is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs); -#endif - add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs); - add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances( - op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs); -#endif - add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs); - add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances( - op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs); -#endif - add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs); - add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances( - op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs); -#endif - add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs); - add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances( - op_ptrs); } } #ifdef CK_ENABLE_FP16 @@ -490,60 +90,160 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); -#ifdef 
DL_KERNELS add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs); -#endif - add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); - add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs); -#endif - add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); - add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); - add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( - op_ptrs); - add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); -#endif - add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); - add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) { - /// add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs); add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); + } + } #endif - add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); +#ifdef CK_ENABLE_INT8 + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs); + add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs); + } + } +#endif +#endif // DL_KERNELS + +#ifdef CK_USE_WMMA +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + 
add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs); } } #endif +#endif + +#ifdef CK_USE_XDL + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances( + op_ptrs); + } + } +#ifdef CK_ENABLE_FP16 + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } +#endif #ifdef CK_ENABLE_BF16 else if constexpr(is_same_v && is_same_v && is_same_v) @@ -578,37 +278,21 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs); -#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs); -#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs); -#endif } else if constexpr(is_same_v && is_same_v && is_same_v) { add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs); -#ifdef DL_KERNELS - add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs); - add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs); -#endif } } #endif @@ -658,6 +342,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(op_ptrs); } } +#endif #endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp 
b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp index 1a518a530..6ee88bd85 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp @@ -16,7 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { -#ifdef CK_ENABLE_FP16 +#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL) void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( std::vector>>& instances); #endif -#ifdef CK_ENABLE_INT8 +#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA) void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances( std::vector> op_ptrs; -#ifdef CK_ENABLE_FP16 +#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL) if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -189,7 +189,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v && is_same_v) { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc new file mode 100644 index 000000000..44a11f628 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_dl.inc @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#if defined(CK_ENABLE_FP16) +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances( + std::vector>>& + instances); +#endif +#if defined(CK_ENABLE_FP32) +void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void 
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#if defined(CK_ENABLE_INT8) +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances( + std::vector>>& + instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc new file mode 100644 index 000000000..c97298c25 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_wmma.inc @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc new file mode 100644 index 000000000..82a1dc425 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_xdl.inc @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
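The backend-specific gemm_dl.inc, gemm_wmma.inc, and gemm_xdl.inc headers introduced here only declare the add_device_gemm_*_instances() functions; the umbrella gemm.hpp includes them behind DL_KERNELS, CK_USE_WMMA, and CK_USE_XDL, and DeviceOperationInstanceFactory::GetInstances() calls whichever adders match the requested layouts and data types. A minimal client-side sketch of how those instances are consumed, assuming the usual CK client API (the header path, RowMajor/ColumnMajor aliases, PassThrough element-wise op, and the DeviceGemm parameter order follow CK's client examples rather than this patch):

    #include <memory>
    #include <vector>

    #include "ck/ck.hpp"
    #include "ck/library/tensor_operation_instance/gpu/gemm.hpp"

    using Row  = ck::tensor_layout::gemm::RowMajor;
    using Col  = ck::tensor_layout::gemm::ColumnMajor;
    using Pass = ck::tensor_operation::element_wise::PassThrough;

    // DeviceGemm<ALayout, BLayout, CLayout, ADataType, BDataType, CDataType,
    //            AElementOp, BElementOp, CElementOp>
    using DeviceOp = ck::tensor_operation::device::
        DeviceGemm<Row, Col, Row, ck::half_t, ck::half_t, ck::half_t, Pass, Pass, Pass>;

    int main()
    {
        // Only instances compiled in for the enabled backends (XDL/WMMA/DL) and
        // enabled data types (CK_ENABLE_*) are returned; the list may be empty.
        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

        return op_ptrs.empty() ? 1 : 0;
    }
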
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_INT8 +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP64 +void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( + std::vector>>& + 
instances); + +void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( + std::vector>>& + instances); +#endif +#ifdef CK_ENABLE_FP8 +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances( + std::vector>>& + instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index 09885ccd9..9a70a4727 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -10,439 +10,18 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef CK_USE_XDL +#include "grouped_convolution_backward_data_xdl.inc" +#endif +#ifdef CK_USE_WMMA +#include "grouped_convolution_backward_data_wmma.inc" +#endif + namespace ck { namespace tensor_operation { namespace device { namespace instance { -// conv2d backward data -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( - std::vector>>& instances); - -#endif -#ifdef CK_ENABLE_FP32 -void 
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -// conv3d backward data -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 -void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( - std::vector>>& instances); -#endif template > op_ptrs; + +#ifdef CK_USE_XDL if constexpr(NumDimSpatial == 2) { - if constexpr(is_same_v && is_same_v && is_same_v) { @@ -500,43 +80,28 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { 
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( op_ptrs); } -#endif -#ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) - { - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( - op_ptrs); - } #endif } - else if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && @@ -544,45 +109,29 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( - op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( op_ptrs); } -#endif -#ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) - { - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs); - add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( - op_ptrs); - } #endif } } - else if constexpr(NumDimSpatial == 3) + if constexpr(NumDimSpatial == 3) { - if constexpr(is_same_v && is_same_v && is_same_v) { @@ -593,45 +142,29 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( op_ptrs); } -#endif -#ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) - { - add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( - op_ptrs); - } #endif } - else if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && @@ -640,44 +173,139 @@ struct DeviceOperationInstanceFactory< { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( op_ptrs); - 
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( - op_ptrs); - add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( - op_ptrs); } #endif #if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( op_ptrs); } #endif #ifdef CK_ENABLE_FP32 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( op_ptrs); } #endif #ifdef CK_ENABLE_BF16 - else if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( op_ptrs); } +#endif + } + } +#endif + +#ifdef CK_USE_WMMA + if constexpr(NumDimSpatial == 2) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( + op_ptrs); + } #endif #ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( + op_ptrs); + } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( + op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs); + add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( + op_ptrs); + } +#endif + } + } + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( + op_ptrs); + } +#endif + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + 
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( op_ptrs); @@ -687,6 +315,7 @@ struct DeviceOperationInstanceFactory< #endif } } +#endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc new file mode 100644 index 000000000..fb2407bcc --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_wmma.inc @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv2d backward data +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc new file mode 100644 index 000000000..7ad021841 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -0,0 +1,216 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances( + std::vector>>& instances); + +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances); +#endif + +// conv3d backward data +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector>>& instances); +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index b8ca2c5fa..dc56b8f4b 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
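Each convolution factory touched by this patch follows the same shape as the GEMM one above: an outer #ifdef per backend (DL_KERNELS, CK_USE_XDL, CK_USE_WMMA), an inner #ifdef per data type (CK_ENABLE_FP16 and friends), and if constexpr dispatch on NumDimSpatial, layouts, and data types, so each enabled backend appends only its own instances to op_ptrs. A distilled, self-contained illustration of that structure (the types and adder functions here are placeholders, not CK symbols):

    #include <memory>
    #include <type_traits>
    #include <vector>

    struct Op  { virtual ~Op() = default; };   // stand-in for a device-op base class
    struct F16 {};                             // stand-in data-type tag
    struct XdlOpF16  : Op {};                  // instance added only under CK_USE_XDL
    struct WmmaOpF16 : Op {};                  // instance added only under CK_USE_WMMA

    template <typename DataType>
    std::vector<std::unique_ptr<Op>> get_instances()
    {
        std::vector<std::unique_ptr<Op>> op_ptrs;
    #if defined(CK_USE_XDL) && defined(CK_ENABLE_FP16)
        if constexpr(std::is_same_v<DataType, F16>)
            op_ptrs.push_back(std::make_unique<XdlOpF16>());
    #endif
    #if defined(CK_USE_WMMA) && defined(CK_ENABLE_FP16)
        if constexpr(std::is_same_v<DataType, F16>)
            op_ptrs.push_back(std::make_unique<WmmaOpF16>());
    #endif
        return op_ptrs;                        // empty if no enabled backend matches
    }

Splitting the adder declarations into per-backend .inc files keeps the backend guards in one place per file instead of scattering CK_USE_* and DL_KERNELS checks through every declaration in the umbrella header.
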
#pragma once @@ -12,564 +12,19 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// xdl -// conv1d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( - std::vector>>& instances); -#endif -// conv2d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); -#endif -// conv3d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector>>& instances); -#endif -#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 -void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( - std::vector>>& instances); - -void 
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif - #ifdef DL_KERNELS -// dl -// conv1d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances( - std::vector>>& instances); +#include "grouped_convolution_backward_weight_dl.inc" #endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances( - std::vector>>& instances); -#endif -// conv2d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); -#endif -// conv3d backward weight -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); -#endif -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector>>& instances); +#ifdef CK_USE_XDL +#include "grouped_convolution_backward_weight_xdl.inc" #endif +#ifdef CK_USE_WMMA +#include "grouped_convolution_backward_weight_wmma.inc" #endif +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { template > op_ptrs; +#ifdef DL_KERNELS if constexpr(NumDimSpatial == 1) { if constexpr(is_same_v && is_same_v && @@ -621,10 +77,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); -#endif - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -632,10 +85,7 @@ struct 
DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); -#endif - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -644,19 +94,14 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( op_ptrs); -#endif - add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( - op_ptrs); } #endif } if constexpr(is_same_v && is_same_v && is_same_v) { -#ifdef DL_KERNELS #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && @@ -682,7 +127,6 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( op_ptrs); -#endif - add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -709,12 +149,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( op_ptrs); -#endif - add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -723,12 +159,8 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( op_ptrs); -#endif - add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( - op_ptrs); } #endif } @@ -740,12 +172,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); -#endif - add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -753,12 +181,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( op_ptrs); -#endif - add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -767,12 +191,8 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( op_ptrs); -#endif - add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( - op_ptrs); } #endif } @@ -787,12 +207,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( op_ptrs); -#endif - add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( - op_ptrs); } #endif #ifdef CK_ENABLE_FP16 @@ -800,15 +216,39 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( op_ptrs); + } #endif - add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( op_ptrs); - add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + } +#endif + } + if constexpr(is_same_v && 
is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); - add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( op_ptrs); } #endif @@ -818,40 +258,125 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( op_ptrs); + } #endif - add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + } + } +#endif // DL_KERNELS +#ifdef CK_USE_XDL + if constexpr(NumDimSpatial == 1) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( op_ptrs); } #endif -#ifdef CK_ENABLE_INT8 - else if constexpr(is_same_v && is_same_v && - is_same_v && - is_same_v && - is_same_v) + } + } + if constexpr(NumDimSpatial == 2) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) { - add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( op_ptrs); - add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( op_ptrs); } #endif } - if constexpr(is_same_v && is_same_v && - is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v) { #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( op_ptrs); + } #endif - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + 
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + op_ptrs); + } +#endif + } + } + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( op_ptrs); } #endif @@ -860,15 +385,39 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( op_ptrs); + } #endif - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( op_ptrs); - add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( + } +#endif + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( op_ptrs); - add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( op_ptrs); } #endif @@ -878,11 +427,36 @@ struct DeviceOperationInstanceFactory && is_same_v) { -#ifdef DL_KERNELS - add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( op_ptrs); + } #endif - add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + op_ptrs); + } +#endif + } + } +#endif +#ifdef CK_USE_WMMA + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( op_ptrs); } #endif @@ -892,23 +466,42 @@ struct DeviceOperationInstanceFactory && is_same_v) { - add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( op_ptrs); - add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( op_ptrs); } #endif -#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v && + is_same_v) { - 
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + else if constexpr(is_same_v && is_same_v && + is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + op_ptrs); + add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( op_ptrs); } #endif } } +#endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc new file mode 100644 index 000000000..59190a13e --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_dl.inc @@ -0,0 +1,243 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv1d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances( + std::vector>>& instances); +#endif +// conv2d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif +// conv3d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); + +void 
add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc new file mode 100644 index 000000000..315547ca5 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_wmma.inc @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv3d backward weight +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc new file mode 100644 index 000000000..5562d236e --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// conv1d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); +#endif +// conv2d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif +// conv3d backward weight +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif +#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8 +void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index b9712542a..24a5f9a5c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -12,907 +12,20 @@ #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -#ifdef CK_ENABLE_BF16 -// grouped conv1d forward, GNWC/GKXC/GNWK -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( - 
std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF16 -// grouped conv2d forward, GNHWC/GKYXC/GNHWK -void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( - std::vector>>& instances); -#endif - -// grouped conv2d forward, NHWGC/GKYXC/NHWGK -#ifdef CK_ENABLE_BF16 -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( - std::vector>>& instances); - -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF16 -// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( - 
std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF16 -// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( - std::vector>>& instances); +#ifdef DL_KERNELS +#include "grouped_convolution_forward_dl.inc" #endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP16 -void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( - std::vector>>& instances); +#ifdef CK_USE_XDL +#include "grouped_convolution_forward_xdl.inc" #endif - -#ifdef CK_ENABLE_FP8 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_BF8 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_FP32 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( - std::vector>>& instances); -#endif - -#ifdef CK_ENABLE_INT8 -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( - std::vector>>& instances); - -void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( - std::vector>>& instances); -#endif - -#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) -void 
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances( - std::vector>>& instances); - -#endif - -#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( - std::vector>>& instances); +#ifdef CK_USE_WMMA +#include "grouped_convolution_forward_wmma.inc" #endif -#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( - std::vector>>& instances); -#endif - -#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS)) -void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( - std::vector>>& instances); -#endif +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { template > op_ptrs; +#ifdef DL_KERNELS + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); + } +#endif + } + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { + +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); + } +#endif + } +#endif // DL_KERNELS + +#ifdef CK_USE_XDL if constexpr(NumDimSpatial == 1 && is_same_v && is_same_v && is_same_v) { @@ -1000,35 +154,13 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(op_ptrs); - } -#endif - -#if(defined(CK_ENABLE_FP16) && defined(DL_KERNELS)) - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); } #endif - #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && @@ -1037,23 +169,11 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(op_ptrs); - } -#endif } if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v) { - #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -1061,15 +181,6 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - 
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -1077,15 +188,6 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); - } -#endif - #ifdef CK_ENABLE_BF16 if constexpr(is_same_v && is_same_v && @@ -1093,16 +195,6 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) - { - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(op_ptrs); - add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(op_ptrs); - } #endif } @@ -1121,12 +213,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -1142,11 +228,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(op_ptrs); } #endif } @@ -1188,12 +269,6 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( - op_ptrs); - add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF16 @@ -1209,6 +284,99 @@ struct DeviceOperationInstanceFactory && is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs); + } +#endif + } +#endif + +#ifdef CK_USE_WMMA + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(op_ptrs); + 
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(op_ptrs); + } +#endif + } + + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(op_ptrs); + add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(op_ptrs); + } +#endif + } + + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(op_ptrs); + } +#endif + } + + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + op_ptrs); + add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( @@ -1217,6 +385,7 @@ struct DeviceOperationInstanceFactory>>& instances); + +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); +#endif +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc new file mode 100644 index 000000000..0ea24d092 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_wmma.inc @@ -0,0 +1,480 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); + +void 
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances( + std::vector>>& instances); + +void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc new file mode 100644 index 000000000..942674ef9 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -0,0 +1,357 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_BF16 +// grouped conv1d forward, GNWC/GKXC/GNWK +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv2d forward, GNHWC/GKYXC/GNHWK +void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( + std::vector>>& instances); +#endif + +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances( + std::vector>>& instances); + +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances( + 
std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances( + std::vector>>& instances); +#endif + +#ifdef CK_ENABLE_BF8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + std::vector>>& instances); +#endif + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 0a12e1c49..c035e7e56 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -36,12 +36,27 @@ function(add_instance_library INSTANCE_NAME) endif() endforeach() endif() + # Do not build DL instances if DL_KERNELS macro is not set foreach(source IN LISTS ARGN) if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") message("removing dl instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() endforeach() + # Do not build XDL instances if gfx9 targets are not on the target list + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + message("removing xdl instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + # Do not build WMMA instances if gfx11 targets are not on the target list + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma") + message("removing wmma instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(ARGN) add_library(${INSTANCE_NAME} OBJECT ${ARGN}) @@ -124,6 +139,26 @@ FOREACH(subdir_path ${dir_list}) message("Found only dl instances, but DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() + if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx9")) + message("Found only xdl instances, but gfx9 is not on the targets list. Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11")) + message("Found only wmma instances, but gfx11 is not on the targets list. Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_DL_KERNELS") AND (NOT DEFINED DL_KERNELS) AND (NOT GPU_TARGETS MATCHES "gfx9")) + message("Found only xdl and dl instances, but gfx9 is not on the targets list and DL_KERNELS is not set. Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9")) + message("Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.") + set(add_inst 0) + endif() + if(("${cmake_instance}" MATCHES "XDL_DL_WMMA_KERNELS") AND (NOT GPU_TARGETS MATCHES "gfx11") AND (NOT GPU_TARGETS MATCHES "gfx9") AND (NOT DEFINED DL_KERNELS)) + message("Found xdl, dl, and wmma instances, but none of those meet the target list. 
Skipping.") + set(add_inst 0) + endif() if((add_inst EQUAL 1)) get_filename_component(target_dir ${subdir_path} NAME) add_subdirectory(${target_dir}) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt index 69b6ddc75..1227a77a3 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(BATCHED_GEMM_INSTANCES) list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt index d0e9b265a..5c8470f7c 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_add_relu_gemm_add_instance device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt index cd9c95c06..8082a8c27 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_bias_permute/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_bias_permute_instance device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt index 865a31e79..2aa607429 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_gemm_instance device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt index 28226faba..51bbdf1d7 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_reduce_instance device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt index 6244477e1..e43eb07fb 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt 
+++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_batched_gemm_softmax_gemm_instance device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt index 3fd4e0344..f1fb0646e 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES) list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt index 87a6bbba4..a28c6717d 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_CONTRACTION_BILINEAR_INSTANCES) # FP32 diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt index a0918d9d3..b91de832e 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_CONTRACTION_SCALE_INSTANCES) # FP32 diff --git a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt index 75a367076..796a9b240 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv1d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv1d_bwd_data_instance device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instance.cpp device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt index 49dfc01fd..2da515511 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(CONV2D_BWD_DATA_INSTANCES) list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt index ba0ca3251..04b313d07 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(DEVICE_CONV2D_FWD_INSTANCES) list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp diff 
--git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt index 670cd94fc..4304d8996 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv2d_fwd_bias_relu_instance device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt index 68d5f582f..40a6b1ff0 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv2d_fwd_bias_relu_add_instance device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt index db92208fd..ec4a8a286 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv3d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_conv3d_bwd_data_instance device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt index fe85bb7ea..298da1fbe 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_instance device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt index bbf81a5fa..04ae90bc5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_add_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_add_fastgelu_instance device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt index 63b4a00c9..45d6abce0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_fastgelu_instance device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp diff --git 
a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt index eb9345cba..d859078ca 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_multiply/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_multiply_instance device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt index 969361de9..043bdab00 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_relu_instance device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt index 97693a256..b9aeb6a6d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_relu_add_layernorm_instance device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt index c10d4773a..e6ca64cdc 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_add_silu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_add_silu_instance device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt index ccada3a85..f29943d93 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bias_add_reduce/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_bias_add_reduce_instance device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt index 426edeed7..61892e708 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_gemm_bilinear_instance 
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt index 17d27ab15..2f45401ec 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_fastgelu_instance device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt index 6cbd7528e..aba9806a7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_multiply_add/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GEMM_MULTIPLY_ADD_INSTANCES) list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt index 2b2cf8c77..7ee3efe7f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_reduce_instance device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index 059b6a720..dac86d770 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GEMM_SPLITK_INSTANCES) list(APPEND GEMM_SPLITK_INSTANCES diff --git a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt index 8dd0112a6..c854b16ee 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_streamk/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_gemm_streamk_instance # device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp # device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt index cfd829f87..ab4313d89 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_bwd_weight/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(GROUPED_CONV1D_BWD_WEIGHT xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt index f51a484bb..ca4ea515b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv1d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_conv1d_fwd_instance xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp xdl/device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 93d5bd742..ad430340e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS add_instance_library( device_grouped_conv2d_bwd_data_instance xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt index 8a896b06c..340ddfb3f 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(GROUPED_CONV2D_BWD_WEIGHT xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 2715a8cf2..1d3c3747d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# XDL_DL_WMMA_KERNELS add_instance_library(device_grouped_conv2d_fwd_instance #xdl # GNHWC, GKYXC, GNHWK diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt index 836e671bf..29fa8fa3c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONV3D_BWD_DATA xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt index e1cb97529..ae6dcb988 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_BWD_DATA_BILINEAR xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt index b7901a281..fa48f0edc 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_BWD_DATA_BILINEAR xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index 968e8dea2..8b89dcf7e 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -1,3 +1,4 @@ +# XDL_DL_WMMA_KERNELS set(GROUPED_CONV3D_BWD_WEIGHT xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 3825b92af..972fb5403 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt index 49706588d..436c37fd5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt index 45d270d55..f36d55d36 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_BILINEAR xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt index 08fb23afc..107624944 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_ab/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS 
set(GROUPED_CONV3D_FWD_SCALEADD_AB xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_ab_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt index ae89caaee..1be1db7d1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scaleadd_scaleadd_relu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_CONV3D_FWD_scaleadd_scaleadd_RELU xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_scaleadd_scaleadd_relu_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index de7537af4..2625e6cbe 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt index ef8a440c1..167dfa9a6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_bias/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_gemm_bias_instance device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt index 648f2146c..8e9693e69 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fastgelu/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS add_instance_library(device_grouped_gemm_fastgelu_instance device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt index ac22543be..1ee3d0add 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_KERNELS set(GROUPED_GEMM_FIXED_NK_INSTANCES) list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt index c22a6e9e9..5d50902be 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt +++ 
b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt @@ -1,3 +1,4 @@ +# ONLY XDL_AND_DL_KERNELS set(CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp) set(CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp) set(CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp) diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 11ae28516..cb6ffbec6 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -2,19 +2,6 @@ set(PROFILER_SOURCES profiler.cpp profile_gemm.cpp - profile_gemm_splitk.cpp - profile_gemm_bias_add_reduce.cpp - profile_gemm_add_multiply.cpp - profile_gemm_multiply_add.cpp - profile_gemm_reduce.cpp - profile_batched_gemm.cpp - profile_batched_gemm_reduce.cpp - profile_conv_fwd.cpp - profile_conv_fwd_bias_relu.cpp - profile_conv_fwd_bias_relu_add.cpp - profile_conv_bwd_data.cpp - profile_grouped_conv_fwd.cpp - profile_grouped_conv_bwd_weight.cpp profile_reduce.cpp profile_groupnorm_bwd_data.cpp profile_groupnorm_fwd.cpp @@ -29,36 +16,57 @@ set(PROFILER_SOURCES profile_batchnorm_fwd.cpp profile_batchnorm_bwd.cpp profile_batchnorm_infer.cpp - profile_grouped_conv_bwd_data.cpp profile_conv_tensor_rearrange.cpp profile_transpose.cpp profile_permute_scale.cpp ) -if(DL_KERNELS) - list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp) +if(GPU_TARGETS MATCHES "gfx9") + if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) + list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_gemm_reduce.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) + endif() + list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_batched_gemm_reduce.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) + list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) + list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) + list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) + list(APPEND PROFILER_SOURCES profile_conv_fwd.cpp) + endif() -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_batched_gemm_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_streamk.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) - list(APPEND 
PROFILER_SOURCES profile_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_add_fastgelu.cpp) - list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) - list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) - list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) +if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx9") + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp) + endif() + list(APPEND PROFILER_SOURCES profile_grouped_conv_fwd.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) endif() -if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp) - list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp) +if(DL_KERNELS) + list(APPEND PROFILER_SOURCES profile_batched_gemm_multi_d.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_conv_bwd_weight.cpp) endif() set(PROFILER_EXECUTABLE ckProfiler) @@ -68,25 +76,6 @@ target_compile_options(${PROFILER_EXECUTABLE} PRIVATE -Wno-global-constructors) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE utility getopt::getopt) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_data_instance) 
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_bwd_gamma_beta_instance) @@ -96,39 +85,65 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_pool3d_fwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_avg_pool3d_bwd_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_max_pool_bwd_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) -target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance) -if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) +if(GPU_TARGETS MATCHES "gfx9") + if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) + endif() + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) + endif() + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE 
device_conv2d_fwd_bias_relu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu_add_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv1d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv3d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) endif() - +if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11") + if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) + endif() + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_data_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_fwd_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) +endif() if(DL_KERNELS) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_multi_d_instance) -endif() - -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_silu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_relu_add_layernorm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_add_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_streamk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_fastgelu_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) - target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv1d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv2d_bwd_weight_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_weight_instance) endif() rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a0f90256c..720ab468e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -46,7 +46,18 @@ function(add_test_executable TEST_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() - + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") + message("removing xdl 
test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + message("removing wmma test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(ARGN) add_executable(${TEST_NAME} ${ARGN}) @@ -100,6 +111,18 @@ function(add_gtest_executable TEST_NAME) list(REMOVE_ITEM ARGN "${source}") endif() endforeach() + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl") + message("removing xdl test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() + foreach(source IN LISTS ARGN) + if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma") + message("removing wmma test ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + endforeach() #only continue if there are some source files left on the list if(ARGN) add_executable(${TEST_NAME} ${ARGN}) diff --git a/test/batched_gemm/CMakeLists.txt b/test/batched_gemm/CMakeLists.txt index 9482821b6..759cf3da6 100644 --- a/test/batched_gemm/CMakeLists.txt +++ b/test/batched_gemm/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_batched_gemm test_batched_gemm.cpp) +add_gtest_executable(test_batched_gemm test_batched_gemm_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_batched_gemm PRIVATE utility device_batched_gemm_instance) - set(target 1) - endif() -endforeach() \ No newline at end of file +endif() diff --git a/test/batched_gemm/test_batched_gemm.cpp b/test/batched_gemm/test_batched_gemm_xdl.cpp similarity index 100% rename from test/batched_gemm/test_batched_gemm.cpp rename to test/batched_gemm/test_batched_gemm_xdl.cpp diff --git a/test/batched_gemm_gemm/CMakeLists.txt b/test/batched_gemm_gemm/CMakeLists.txt index 03f1d3a4e..2b3288ef9 100644 --- a/test/batched_gemm_gemm/CMakeLists.txt +++ b/test/batched_gemm_gemm/CMakeLists.txt @@ -1,13 +1,6 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_batched_gemm_gemm) - add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) - add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) - set(target 1) - endif() - endif() -endforeach() \ No newline at end of file +add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16_xdl.cpp) +if(result EQUAL 0) + add_custom_target(test_batched_gemm_gemm) + target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance) + add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16) +endif() diff --git a/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp b/test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp rename to test/batched_gemm_gemm/test_batched_gemm_gemm_fp16_xdl.cpp diff --git a/test/batched_gemm_reduce/CMakeLists.txt b/test/batched_gemm_reduce/CMakeLists.txt index 32c6ee85d..c5868e4d7 100644 --- a/test/batched_gemm_reduce/CMakeLists.txt +++ b/test/batched_gemm_reduce/CMakeLists.txt @@ 
-1,11 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance) - set(target 1) - endif() +add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance) endif() -endforeach() diff --git a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp b/test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp rename to test/batched_gemm_reduce/batched_gemm_reduce_fp16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm/CMakeLists.txt b/test/batched_gemm_softmax_gemm/CMakeLists.txt index c011a6a3c..c042d7e00 100644 --- a/test/batched_gemm_softmax_gemm/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm/CMakeLists.txt @@ -1,13 +1,6 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_batched_gemm_softmax_gemm) - add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) - add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) - set(target 1) - endif() - endif() -endforeach() \ No newline at end of file +add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16_xdl.cpp) +if(result EQUAL 0) + add_custom_target(test_batched_gemm_softmax_gemm) + target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) + add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) +endif() diff --git a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp rename to test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt index 3164863ee..2e0907354 100644 --- a/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt +++ b/test/batched_gemm_softmax_gemm_permute/CMakeLists.txt @@ -1,29 +1,21 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_batched_gemm_softmax_gemm_permute) - add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) - endif() - add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp) 
- if(result EQUAL 0) - target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) - endif() - - add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) - endif() - add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp) - if(result EQUAL 0) - target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) - add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) - endif() - set(target 1) - endif() -endforeach() \ No newline at end of file +add_custom_target(test_batched_gemm_softmax_gemm_permute) +add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16) +endif() +add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16) +endif() +add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16) +endif() +add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance) + add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16) +endif() diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp rename 
to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16_xdl.cpp diff --git a/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp b/test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp similarity index 100% rename from test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp rename to test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16_xdl.cpp diff --git a/test/contraction/CMakeLists.txt b/test/contraction/CMakeLists.txt index a86e72fdd..3ba0d82f0 100644 --- a/test/contraction/CMakeLists.txt +++ b/test/contraction/CMakeLists.txt @@ -1,13 +1,10 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES) - add_gtest_executable(test_contraction test_contraction.cpp) - target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) - add_gtest_executable(test_contraction_interface test_contraction_interface.cpp) - target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) - set(target 1) - endif() +if((DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64") OR NOT DEFINED DTYPES) + add_gtest_executable(test_contraction test_contraction_xdl.cpp) + if(result EQUAL 0) + target_link_libraries(test_contraction PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) endif() -endforeach() + add_gtest_executable(test_contraction_interface test_contraction_interface_xdl.cpp) + if(result EQUAL 0) + target_link_libraries(test_contraction_interface PRIVATE utility device_contraction_bilinear_instance device_contraction_scale_instance) + endif() +endif() diff --git a/test/contraction/test_contraction_interface.cpp b/test/contraction/test_contraction_interface_xdl.cpp similarity index 100% rename from test/contraction/test_contraction_interface.cpp rename to test/contraction/test_contraction_interface_xdl.cpp diff --git a/test/contraction/test_contraction.cpp b/test/contraction/test_contraction_xdl.cpp similarity index 100% rename from test/contraction/test_contraction.cpp rename to test/contraction/test_contraction_xdl.cpp diff --git a/test/convnd_bwd_data/CMakeLists.txt b/test/convnd_bwd_data/CMakeLists.txt index f734b46f5..e68a9b243 100644 --- a/test/convnd_bwd_data/CMakeLists.txt +++ b/test/convnd_bwd_data/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_convnd_bwd_data convnd_bwd_data.cpp) +add_gtest_executable(test_convnd_bwd_data convnd_bwd_data_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_convnd_bwd_data PRIVATE utility device_conv1d_bwd_data_instance device_conv2d_bwd_data_instance device_conv3d_bwd_data_instance) - 
set(target 1) - endif() -endforeach() \ No newline at end of file +endif() diff --git a/test/convnd_bwd_data/convnd_bwd_data.cpp b/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp similarity index 100% rename from test/convnd_bwd_data/convnd_bwd_data.cpp rename to test/convnd_bwd_data/convnd_bwd_data_xdl.cpp diff --git a/test/convnd_fwd/CMakeLists.txt b/test/convnd_fwd/CMakeLists.txt index 745aceffc..ba6d16a0d 100644 --- a/test/convnd_fwd/CMakeLists.txt +++ b/test/convnd_fwd/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_convnd_fwd convnd_fwd.cpp) +add_gtest_executable(test_convnd_fwd convnd_fwd_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_convnd_fwd PRIVATE utility device_conv2d_fwd_instance) - set(target 1) - endif() -endforeach() +endif() diff --git a/test/convnd_fwd/convnd_fwd.cpp b/test/convnd_fwd/convnd_fwd_xdl.cpp similarity index 100% rename from test/convnd_fwd/convnd_fwd.cpp rename to test/convnd_fwd/convnd_fwd_xdl.cpp diff --git a/test/gemm_add/CMakeLists.txt b/test/gemm_add/CMakeLists.txt index 7df3f90ab..ab4c78184 100644 --- a/test/gemm_add/CMakeLists.txt +++ b/test/gemm_add/CMakeLists.txt @@ -1,11 +1,19 @@ -add_gtest_executable(test_gemm_add test_gemm_add.hpp) -target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance) +add_gtest_executable(test_gemm_add test_gemm_add_xdl.hpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add PRIVATE utility device_gemm_add_instance) +endif() -add_gtest_executable(test_gemm_add_relu test_gemm_add_relu.cpp) -target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) +add_gtest_executable(test_gemm_add_relu test_gemm_add_relu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_relu PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance) +endif() -add_gtest_executable(test_gemm_add_silu test_gemm_add_silu.cpp) -target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) +add_gtest_executable(test_gemm_add_silu test_gemm_add_silu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_silu PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance) +endif() -add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu.cpp) -target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) +add_gtest_executable(test_gemm_add_fastgelu test_gemm_add_fastgelu_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_add_fastgelu PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance) +endif() diff --git a/test/gemm_add/test_gemm_add_fastgelu.cpp b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp similarity index 98% rename from test/gemm_add/test_gemm_add_fastgelu.cpp rename to test/gemm_add/test_gemm_add_fastgelu_xdl.cpp index c1c55140a..1b12ab752 100644 --- a/test/gemm_add/test_gemm_add_fastgelu.cpp +++ b/test/gemm_add/test_gemm_add_fastgelu_xdl.cpp @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_fastgelu_impl.hpp" -#include "test_gemm_add.hpp" +#include "test_gemm_add_xdl.hpp" template class TestGemmAddFastgelu : public TestGemmAdd diff --git a/test/gemm_add/test_gemm_add_relu.cpp b/test/gemm_add/test_gemm_add_relu_xdl.cpp similarity index 98% rename from 
test/gemm_add/test_gemm_add_relu.cpp rename to test/gemm_add/test_gemm_add_relu_xdl.cpp index ba6aab36b..e8b769b1c 100644 --- a/test/gemm_add/test_gemm_add_relu.cpp +++ b/test/gemm_add/test_gemm_add_relu_xdl.cpp @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_relu_impl.hpp" -#include "test_gemm_add.hpp" +#include "test_gemm_add_xdl.hpp" template class TestGemmAddRelu : public TestGemmAdd diff --git a/test/gemm_add/test_gemm_add_silu.cpp b/test/gemm_add/test_gemm_add_silu_xdl.cpp similarity index 98% rename from test/gemm_add/test_gemm_add_silu.cpp rename to test/gemm_add/test_gemm_add_silu_xdl.cpp index d4dd6fa38..75fa59a8e 100644 --- a/test/gemm_add/test_gemm_add_silu.cpp +++ b/test/gemm_add/test_gemm_add_silu_xdl.cpp @@ -4,7 +4,7 @@ #include "gtest/gtest.h" #include "ck/ck.hpp" #include "profiler/profile_gemm_add_silu_impl.hpp" -#include "test_gemm_add.hpp" +#include "test_gemm_add_xdl.hpp" template class TestGemmAddSilu : public TestGemmAdd diff --git a/test/gemm_add/test_gemm_add.hpp b/test/gemm_add/test_gemm_add_xdl.hpp similarity index 100% rename from test/gemm_add/test_gemm_add.hpp rename to test/gemm_add/test_gemm_add_xdl.hpp diff --git a/test/gemm_layernorm/CMakeLists.txt b/test/gemm_layernorm/CMakeLists.txt index bfc4404bd..d1102a561 100644 --- a/test/gemm_layernorm/CMakeLists.txt +++ b/test/gemm_layernorm/CMakeLists.txt @@ -1,13 +1,6 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_gemm_layernorm) - add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp) - if(result EQUAL 0) - target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) - add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) - set(target 1) - endif() - endif() -endforeach() +add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16_xdl.cpp) +if(result EQUAL 0) + add_custom_target(test_gemm_layernorm) + target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance) + add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16) +endif() diff --git a/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp b/test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16_xdl.cpp similarity index 100% rename from test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16.cpp rename to test/gemm_layernorm/test_gemm_add_relu_add_layernorm_fp16_xdl.cpp diff --git a/test/gemm_reduce/CMakeLists.txt b/test/gemm_reduce/CMakeLists.txt index 42a53c304..121ecde60 100644 --- a/test/gemm_reduce/CMakeLists.txt +++ b/test/gemm_reduce/CMakeLists.txt @@ -1,4 +1,4 @@ -add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp) +add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16_xdl.cpp) if(result EQUAL 0) target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility device_gemm_reduce_instance) endif() \ No newline at end of file diff --git a/test/gemm_reduce/gemm_reduce_fp16.cpp b/test/gemm_reduce/gemm_reduce_fp16_xdl.cpp similarity index 100% rename from test/gemm_reduce/gemm_reduce_fp16.cpp rename to test/gemm_reduce/gemm_reduce_fp16_xdl.cpp diff --git a/test/gemm_split_k/CMakeLists.txt b/test/gemm_split_k/CMakeLists.txt index caf30fca5..4b66dddef 100644 --- a/test/gemm_split_k/CMakeLists.txt +++ 
b/test/gemm_split_k/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_gemm_splitk test_gemm_splitk.cpp) +add_gtest_executable(test_gemm_splitk test_gemm_splitk_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_gemm_splitk PRIVATE utility device_gemm_splitk_instance) - set(target 1) endif() -endforeach() diff --git a/test/gemm_split_k/test_gemm_splitk.cpp b/test/gemm_split_k/test_gemm_splitk_xdl.cpp similarity index 100% rename from test/gemm_split_k/test_gemm_splitk.cpp rename to test/gemm_split_k/test_gemm_splitk_xdl.cpp diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index 305c568ee..3507989ba 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -1,19 +1,12 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) - target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) - add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) - set(target 1) - endif() - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp) - target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) - add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp) - target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) - set(target 1) - endif() -endforeach() \ No newline at end of file +add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_xdl_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance) +endif() +add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) +endif() +add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance) +endif() diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp rename to test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_xdl_wmma.cpp diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index d7d6f8a3d..34cdc63cd 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ 
b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -1,20 +1,12 @@ -list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103) - -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) +add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight_xdl_wmma.cpp) +if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) - add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp) +endif() +add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) - set(target 1) - endif() - if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0) - add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp) - target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance) - add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp) +endif() +add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp) +if(result EQUAL 0) target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility) - set(target 1) - endif() -endforeach() \ No newline at end of file +endif() diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp rename to test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_xdl_wmma.cpp diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 1ce878d5c..4f245d63c 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -1,8 +1,14 @@ -add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd.cpp) -target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) +add_gtest_executable(test_grouped_convnd_fwd test_grouped_convnd_fwd_xdl_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_fwd PRIVATE utility device_grouped_conv1d_fwd_instance device_grouped_conv2d_fwd_instance device_grouped_conv3d_fwd_instance) +endif() add_gtest_executable(test_grouped_convnd_fwd_multi_ab_interface test_grouped_convnd_fwd_multi_ab_interface.cpp) -target_link_libraries(test_grouped_convnd_fwd_multi_ab_interface PRIVATE utility) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_fwd_multi_ab_interface PRIVATE utility) +endif() -add_gtest_executable(test_grouped_convnd_fwd_multi_d_interface_compatibility test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp) -target_link_libraries(test_grouped_convnd_fwd_multi_d_interface_compatibility PRIVATE utility device_grouped_conv3d_fwd_instance) 
+add_gtest_executable(test_grouped_convnd_fwd_multi_d_interface_compatibility test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_convnd_fwd_multi_d_interface_compatibility PRIVATE utility device_grouped_conv3d_fwd_instance) +endif() diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility.cpp rename to test/grouped_convnd_fwd/test_grouped_convnd_fwd_multi_d_interface_compatibility_xdl_wmma.cpp diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_xdl_wmma.cpp similarity index 100% rename from test/grouped_convnd_fwd/test_grouped_convnd_fwd.cpp rename to test/grouped_convnd_fwd/test_grouped_convnd_fwd_xdl_wmma.cpp diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index 8c57b667e..f47685cf9 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -1,14 +1,13 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(test_grouped_gemm) - add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp) - add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface.cpp) - target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) - target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) - - add_dependencies(test_grouped_gemm test_grouped_gemm_splitk test_grouped_gemm_interface) - set(target 1) - endif() -endforeach() +add_custom_target(test_grouped_gemm) + +add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) +endif() + +add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_interface) +endif() diff --git a/test/grouped_gemm/test_grouped_gemm_interface.cpp b/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp similarity index 100% rename from test/grouped_gemm/test_grouped_gemm_interface.cpp rename to test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp diff --git a/test/grouped_gemm/test_grouped_gemm_splitk.cpp b/test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp similarity index 100% rename from test/grouped_gemm/test_grouped_gemm_splitk.cpp rename to test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp diff --git a/test/normalization_bwd_data/CMakeLists.txt b/test/normalization_bwd_data/CMakeLists.txt index 1b6decfed..65f33da74 100644 --- a/test/normalization_bwd_data/CMakeLists.txt +++ b/test/normalization_bwd_data/CMakeLists.txt @@ -1,13 +1,8 @@ add_custom_target(test_normalization_bwd_data) add_gtest_executable(test_layernorm2d_bwd_data_fp32 test_layernorm2d_bwd_data_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_layernorm2d_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) - 
add_dependencies(test_normalization_bwd_data test_layernorm2d_bwd_data_fp32) -endif() +target_link_libraries(test_layernorm2d_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) +add_dependencies(test_normalization_bwd_data test_layernorm2d_bwd_data_fp32) add_gtest_executable(test_groupnorm_bwd_data_fp32 test_groupnorm_bwd_data_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_groupnorm_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) - add_dependencies(test_normalization_bwd_data test_groupnorm_bwd_data_fp32) -endif() - +target_link_libraries(test_groupnorm_bwd_data_fp32 PRIVATE utility device_normalization_bwd_data_instance) +add_dependencies(test_normalization_bwd_data test_groupnorm_bwd_data_fp32) diff --git a/test/normalization_bwd_gamma_beta/CMakeLists.txt b/test/normalization_bwd_gamma_beta/CMakeLists.txt index f3579aad0..afb78dc58 100644 --- a/test/normalization_bwd_gamma_beta/CMakeLists.txt +++ b/test/normalization_bwd_gamma_beta/CMakeLists.txt @@ -1,13 +1,8 @@ add_custom_target(test_normalization_bwd_gamma_beta) add_gtest_executable(test_layernorm2d_bwd_gamma_beta_fp32 test_layernorm2d_bwd_gamma_beta_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_layernorm2d_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) - add_dependencies(test_normalization_bwd_gamma_beta test_layernorm2d_bwd_gamma_beta_fp32) -endif() +target_link_libraries(test_layernorm2d_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) +add_dependencies(test_normalization_bwd_gamma_beta test_layernorm2d_bwd_gamma_beta_fp32) add_gtest_executable(test_groupnorm_bwd_gamma_beta_fp32 test_groupnorm_bwd_gamma_beta_fp32.cpp) -if(result EQUAL 0) - target_link_libraries(test_groupnorm_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) - add_dependencies(test_normalization_bwd_gamma_beta test_groupnorm_bwd_gamma_beta_fp32) -endif() - +target_link_libraries(test_groupnorm_bwd_gamma_beta_fp32 PRIVATE utility device_normalization_bwd_gamma_beta_instance) +add_dependencies(test_normalization_bwd_gamma_beta test_groupnorm_bwd_gamma_beta_fp32) diff --git a/test/permute_scale/CMakeLists.txt b/test/permute_scale/CMakeLists.txt index be6aaf94a..d63cb7991 100644 --- a/test/permute_scale/CMakeLists.txt +++ b/test/permute_scale/CMakeLists.txt @@ -1,6 +1,4 @@ add_custom_target(test_permute) add_gtest_executable(test_permute_scale test_permute_scale.cpp) -if(result EQUAL 0) - target_link_libraries(test_permute_scale PRIVATE utility device_permute_scale_instance) - add_dependencies(test_permute test_permute_scale) -endif() +target_link_libraries(test_permute_scale PRIVATE utility device_permute_scale_instance) +add_dependencies(test_permute test_permute_scale) diff --git a/test/transpose/CMakeLists.txt b/test/transpose/CMakeLists.txt index 530cc9d72..fb9379bea 100644 --- a/test/transpose/CMakeLists.txt +++ b/test/transpose/CMakeLists.txt @@ -1,9 +1,4 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_gtest_executable(test_transpose test_transpose.cpp) - target_link_libraries(test_transpose PRIVATE utility device_transpose_instance) - set(target 1) - endif() -endforeach() +add_gtest_executable(test_transpose test_transpose_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_transpose PRIVATE utility device_transpose_instance) +endif() diff --git 
a/test/transpose/test_transpose.cpp b/test/transpose/test_transpose_xdl.cpp similarity index 100% rename from test/transpose/test_transpose.cpp rename to test/transpose/test_transpose_xdl.cpp diff --git a/test/wrapper/CMakeLists.txt b/test/wrapper/CMakeLists.txt index 383707828..1eb6c35db 100644 --- a/test/wrapper/CMakeLists.txt +++ b/test/wrapper/CMakeLists.txt @@ -12,10 +12,8 @@ add_dependencies(test_wrapper test_wrapper_copy) add_gtest_executable(test_wrapper_partition test_wrapper_partition.cpp) target_link_libraries(test_wrapper_partition PRIVATE utility) add_dependencies(test_wrapper test_wrapper_partition) -if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR - GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR - GPU_TARGETS MATCHES "gfx942") - add_gtest_executable(test_wrapper_gemm test_wrapper_gemm.cpp) +add_gtest_executable(test_wrapper_gemm test_wrapper_gemm_xdl.cpp) +if(result EQUAL 0) target_link_libraries(test_wrapper_gemm PRIVATE utility) add_dependencies(test_wrapper test_wrapper_gemm) endif() diff --git a/test/wrapper/test_wrapper_gemm.cpp b/test/wrapper/test_wrapper_gemm_xdl.cpp similarity index 100% rename from test/wrapper/test_wrapper_gemm.cpp rename to test/wrapper/test_wrapper_gemm_xdl.cpp -- GitLab From 9a194837af0e0d71399d751d9a30f5b6ee4843ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 3 Apr 2024 00:23:49 +0200 Subject: [PATCH 16/63] Introduce combined elementwise ops (#1217) * Introduce combined elementwise ops * Introduce refrence elementwise --- example/44_elementwise_permute/CMakeLists.txt | 2 + .../elementwise_binary_4D_fp16.cpp | 140 +++++ .../elementwise_permute.cpp | 67 +-- .../elementwise_permute_3d.cpp | 51 +- .../elementwise_permute_4D_fp16.cpp | 54 +- .../elementwise_permute_4D_fp16_2d.cpp | 56 +- .../elementwise_permute_4D_fp16_col.cpp | 87 ++- .../elementwise_permute_4D_fp16_row.cpp | 73 +-- .../elementwise_permute_4D_fp32_col.cpp | 85 +-- .../elementwise_permute_4D_fp32_row.cpp | 72 +-- .../elementwise_trinary_4D_fp16.cpp | 156 +++++ .../element/binary_element_wise_operation.hpp | 104 ++++ .../combined_element_wise_operation.hpp | 103 ++++ .../element/unary_element_wise_operation.hpp | 248 +++++++- ...idwise_elementwise_dynamic_vector_dims.hpp | 16 +- include/ck/utility/math_v2.hpp | 556 +++++++++++++++++- .../cpu/reference_elementwise.hpp | 110 ++++ .../profiler/profile_permute_scale_impl.hpp | 22 +- 18 files changed, 1694 insertions(+), 308 deletions(-) create mode 100644 example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp create mode 100644 example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp diff --git a/example/44_elementwise_permute/CMakeLists.txt b/example/44_elementwise_permute/CMakeLists.txt index a963399dc..3cf481250 100644 --- a/example/44_elementwise_permute/CMakeLists.txt +++ b/example/44_elementwise_permute/CMakeLists.txt @@ -4,6 +4,8 @@ add_example_executable(example_elementwise_permute_4D_fp32_row elementwise_permu add_example_executable(example_elementwise_permute_4D_fp16_row elementwise_permute_4D_fp16_row.cpp) add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permute_4D_fp32_col.cpp) add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp) 
+add_example_executable(example_elementwise_binary_4D_fp16 elementwise_binary_4D_fp16.cpp) +add_example_executable(example_elementwise_trinary_4D_fp16 elementwise_trinary_4D_fp16.cpp) add_example_executable(example_elementwise_permute elementwise_permute.cpp) if((NOT GPU_TARGETS MATCHES "gfx940") AND (NOT GPU_TARGETS MATCHES "gfx941") AND (NOT GPU_TARGETS MATCHES "gfx942")) add_example_executable(example_elementwise_permute_3d elementwise_permute_3d.cpp) diff --git a/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp new file mode 100644 index 000000000..8819bb65e --- /dev/null +++ b/example/44_elementwise_permute/elementwise_binary_4D_fp16.cpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; +using BinaryAdd = ck::tensor_operation::element_wise::Add; +// B = alpha * A0 * A0 + beta * A1 * A1 +using BinaryAddUnaryScaleSquare = ck::tensor_operation::element_wise:: + BinaryWithUnaryCombinedOp; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + BinaryAddUnaryScaleSquare, // ElementwiseOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8, 8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 128, 32, 64}; + std::array ab_lengths; + std::array ab_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), + static_cast(nchw[2] * nchw[3]), + static_cast(nchw[3]), + 1}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 2> as = {Tensor(ab_lengths, ab_strides), + Tensor(ab_lengths, ab_strides)}; + Tensor& a0 = as[0]; + Tensor& a1 = as[1]; + Tensor b(ab_lengths, ab_strides); + float alpha = 3.f; + float beta = 2.f; + a0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0.mData.data()); + a1_device_buf.ToDevice(a1.mData.data()); + + std::array inputs = {a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + auto broadcastPermute = 
DeviceElementwisePermuteInstance{}; + auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}}; + auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}}; + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, + {ab_strides, ab_strides}, + {ab_strides}, + inputs, + output, + BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1}); + + if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A0 (nchw): " << a0.mDesc << std::endl; + std::cout << "A1 (nchw): " << a1.mDesc << std::endl; + std::cout << "B (nchw): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor host_b(ab_lengths, ab_strides); + + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<2, ADataType, BDataType, BinaryAddUnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, + host_b, + BinaryAddUnaryScaleSquare{BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); + pass &= + ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + } + + return pass ? 
0 : 1; +} diff --git a/example/44_elementwise_permute/elementwise_permute.cpp b/example/44_elementwise_permute/elementwise_permute.cpp index 24e161c6d..d3c3085eb 100644 --- a/example/44_elementwise_permute/elementwise_permute.cpp +++ b/example/44_elementwise_permute/elementwise_permute.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -30,20 +32,6 @@ using DeviceElementwisePermuteInstance = ck::Sequence<1>, // InScalarPerVectorSeq ck::Sequence<1>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor) -{ - for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c) - for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d) - for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h) - for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w) - { - auto a_val = A_ncdhw(n, c, d, h, w); - functor(B_ndhwc(n, d, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -51,32 +39,7 @@ int main() std::vector ncdhw = {16, 8, 8, 8, 8}; std::vector ndhwc = {16, 8, 8, 8, 8}; - Tensor a(ncdhw); - Tensor b(ndhwc); - - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths; - /**std::array a_strides = { - static_cast(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]), - static_cast(ncdhw[2] * ncdhw[3] * ncdhw[4]), - static_cast(ncdhw[3] * ncdhw[4]), - static_cast(ncdhw[4]), - 1}; - std::array b_strides = { - static_cast(ndhwc[1] * ndhwc[2] * ndhwc[3] * ndhwc[4]), - static_cast(ndhwc[2] * ndhwc[3] * ndhwc[4]), - 1, - static_cast(ndhwc[3] * ndhwc[4]), - static_cast(ndhwc[4])};**/ std::array a_strides = { static_cast(ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]), @@ -93,6 +56,20 @@ int main() 1}; ck::ranges::copy(ncdhw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -126,10 +103,16 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(ndhwc); - host_elementwise4D(host_b, a, PassThrough{}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto 
ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_3d.cpp b/example/44_elementwise_permute/elementwise_permute_3d.cpp index f3aca57c3..47d8c4de6 100644 --- a/example/44_elementwise_permute/elementwise_permute_3d.cpp +++ b/example/44_elementwise_permute/elementwise_permute_3d.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_3d_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -34,20 +36,6 @@ using DeviceElementwisePermuteInstance = ck::Sequence<4>, // InScalarPerVectorSeq ck::Sequence<4>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_ndhwc, const HostTensorA& A_ncdhw, Functor functor) -{ - for(std::size_t n = 0; n < A_ncdhw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_ncdhw.mDesc.GetLengths()[1]; ++c) - for(std::size_t d = 0; d < A_ncdhw.mDesc.GetLengths()[2]; ++d) - for(std::size_t h = 0; h < A_ncdhw.mDesc.GetLengths()[3]; ++h) - for(std::size_t w = 0; w < A_ncdhw.mDesc.GetLengths()[4]; ++w) - { - auto a_val = A_ncdhw(n, c, d, h, w); - functor(B_ndhwc(n, d, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -59,10 +47,13 @@ int main() const int W = 5; const int D = 16; - std::vector ncdhw = {N, C, D, H, W}; - std::vector ndhwc = {N, D, H, W, C}; - Tensor a(ncdhw); - Tensor b(ndhwc); + std::array ab_lengths{N, C, H, W, D}; + std::array a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W + std::array b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -74,10 +65,6 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths{N, C, H, W, D}; - std::array a_strides = {C * D * H * W, H * W, W, 1, D * H * W}; // N, C, D, H, W - std::array b_strides = {C * H * W * D, H * W * D, W * D, D, 1}; // N, D, H, W, C - auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -94,11 +81,12 @@ int main() auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); float ave_time = broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]; + std::size_t flop = std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * + ab_lengths[3] * ab_lengths[4]; std::size_t num_btype = - sizeof(ADataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]) + - sizeof(BDataType) * (ncdhw[0] * ncdhw[1] * ncdhw[2] * ncdhw[3] * ncdhw[4]); + (sizeof(ADataType) + sizeof(BDataType)) * + (ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3] * ab_lengths[4]); float tflops = static_cast(flop) / 1.E9 / ave_time; 
@@ -111,10 +99,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(ndhwc); - host_elementwise4D(host_b, a, PassThrough{}); + Tensor host_b(ab_lengths, b_strides); + + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index 1b28a901c..3ea1aa4bf 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -35,19 +37,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) - for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) - for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -55,18 +44,6 @@ int main() std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; - Tensor a(nchw); - Tensor b(nhwc); - - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; std::array ab_lengths; std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), @@ -77,9 +54,22 @@ int main() 1, static_cast(nhwc[2] * nhwc[3]), static_cast(nhwc[3])}; - ck::ranges::copy(nchw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -111,10 +101,16 @@ int main() if(do_verification) { - 
b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, PassThrough{}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp index 30231a375..1747e6dd8 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_2d.cpp @@ -8,6 +8,8 @@ #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_2d_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -30,22 +32,6 @@ using DeviceElementwisePermuteInstance = ck::Sequence<1>, // InScalarPerVectorSeq ck::Sequence<1>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, - const HostTensorA& A_nchw, - const std::vector& shape_nchw, - Functor functor) -{ - for(std::size_t n = 0; n < shape_nchw[0]; ++n) - for(std::size_t c = 0; c < shape_nchw[1]; ++c) - for(std::size_t h = 0; h < shape_nchw[2]; ++h) - for(std::size_t w = 0; w < shape_nchw[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -54,13 +40,16 @@ int main() const int N = 120; const int C = 128; const int H = 32; - const int W = 1024; + const int W = 32; - std::vector nchw = {N, C, H, W}; - std::vector nhwc = {N, H, W, C}; + std::array ab_lengths{N, H, W, C}; + + std::array a_strides = {C * H * W, W, 1, H * W}; + std::array b_strides = {H * W * C, W * C, C, 1}; - Tensor a(nchw); - Tensor b(nhwc); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -72,11 +61,6 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths{N, H, W, C}; - - std::array a_strides = {C * H * W, W, 1, H * W}; - std::array b_strides = {H * W * C, W * C, C, 1}; - auto broadcastPermute = DeviceElementwisePermuteInstance{}; auto argument = broadcastPermute.MakeArgumentPointer( ab_lengths, {a_strides}, {b_strides}, input, output, PassThrough{}); @@ -94,10 +78,11 @@ int main() float ave_time = broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + std::size_t flop = + std::size_t(2) * ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * ab_lengths[3]; - std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + - sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + std::size_t num_btype = (sizeof(ADataType) + sizeof(BDataType)) * + (ab_lengths[0] * ab_lengths[1] * ab_lengths[2] * 
ab_lengths[3]); float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -110,11 +95,16 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, PassThrough>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); - Tensor host_b(nhwc); - host_elementwise4D, Tensor, PassThrough>( - host_b, a, nchw, PassThrough{}); + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, PassThrough{}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp index f832601f0..13c67fce0 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -6,9 +6,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -21,11 +23,14 @@ using F32 = float; using ADataType = F16; using BDataType = F16; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -36,23 +41,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - std::size_t N = A_nchw.mDesc.GetLengths()[0]; - std::size_t C = A_nchw.mDesc.GetLengths()[1]; - std::size_t H = A_nchw.mDesc.GetLengths()[2]; - std::size_t W = A_nchw.mDesc.GetLengths()[3]; - for(std::size_t w = 0; w < W; ++w) - for(std::size_t h = 0; h < H; ++h) - for(std::size_t c = 0; c < C; ++c) - for(std::size_t n = 0; n < N; ++n) - { - auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; - functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val); - } -} - int main() { bool do_verification = true; @@ -60,8 +48,21 @@ int main() std::vector nchw = {16, 8, 32, 64}; std::vector nhwc = {16, 32, 64, 8}; - Tensor a(nchw); - Tensor b(nhwc); + std::array ab_lengths; + std::array a_strides = {1, + static_cast(nchw[0]), + static_cast(nchw[0] * nchw[1]), + static_cast(nchw[0] * nchw[1] * nchw[2])}; + + std::array b_strides = {1, + static_cast(nhwc[0] * nhwc[1] * nhwc[2]), + static_cast(nhwc[0]), + static_cast(nhwc[0] * nhwc[1])}; + 
ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); float scale = 1.f; auto i = 0; std::mt19937 gen(11939); @@ -84,22 +85,14 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths; - - std::array a_strides = {1, - static_cast(nchw[0]), - static_cast(nchw[0] * nchw[1]), - static_cast(nchw[0] * nchw[1] * nchw[2])}; - - std::array b_strides = {1, - static_cast(nhwc[0] * nhwc[1] * nhwc[2]), - static_cast(nhwc[0]), - static_cast(nhwc[0] * nhwc[1])}; - ck::ranges::copy(nchw, ab_lengths.begin()); - auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -113,11 +106,10 @@ int main() auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); float ave_time = broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); - std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; - - std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + - sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + std::size_t num_btype = + (2 * sizeof(ADataType) + sizeof(BDataType)) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; @@ -129,10 +121,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp index bae85f53c..0a0f6fec1 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -5,9 +5,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -20,11 +22,14 @@ using F32 = float; using ADataType = F16; using BDataType = F16; -using UnaryOp = 
ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -35,19 +40,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) - for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) - for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -55,18 +47,6 @@ int main() std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; - Tensor a(nchw); - Tensor b(nhwc); - float scale = 2.f; - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; std::array ab_lengths; std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), @@ -80,9 +60,29 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + + float scale = 2.f; + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -112,10 +112,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 
1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp index fe7acd301..fc664186b 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -5,9 +5,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -20,11 +22,14 @@ using F32 = float; using ADataType = F32; using BDataType = F32; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -35,32 +40,29 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<1>, // InScalarPerVectorSeq ck::Sequence<1>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - std::size_t N = A_nchw.mDesc.GetLengths()[0]; - std::size_t C = A_nchw.mDesc.GetLengths()[1]; - std::size_t H = A_nchw.mDesc.GetLengths()[2]; - std::size_t W = A_nchw.mDesc.GetLengths()[3]; - for(std::size_t w = 0; w < W; ++w) - for(std::size_t h = 0; h < H; ++h) - for(std::size_t c = 0; c < C; ++c) - for(std::size_t n = 0; n < N; ++n) - { - auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; - functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val); - } -} - int main() { bool do_verification = true; bool time_kernel = true; - std::vector nchw = {5, 4, 2, 3}; - std::vector nhwc = {5, 2, 3, 4}; - Tensor a(nchw); - Tensor b(nhwc); + std::vector nchw = {16, 8, 32, 64}; + std::vector nhwc = {16, 32, 64, 8}; + std::array ab_lengths; + + std::array a_strides = {1, + static_cast(nchw[0]), + static_cast(nchw[0] * nchw[1]), + static_cast(nchw[0] * nchw[1] * nchw[2])}; + + std::array b_strides = {1, + static_cast(nhwc[0] * nhwc[1] * nhwc[2]), + static_cast(nhwc[0]), + static_cast(nhwc[0] * nhwc[1])}; + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); float scale = 1.f; auto i = 0; @@ -84,22 +86,14 @@ int main() std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - std::array ab_lengths; - - std::array a_strides = {1, - static_cast(nchw[0]), - static_cast(nchw[0] * nchw[1]), - static_cast(nchw[0] * nchw[1] * nchw[2])}; - - std::array b_strides = {1, - static_cast(nhwc[0] * nhwc[1] * nhwc[2]), - static_cast(nhwc[0]), - static_cast(nhwc[0] * nhwc[1])}; - ck::ranges::copy(nchw, ab_lengths.begin()); - auto broadcastPermute = DeviceElementwisePermuteInstance{}; - 
auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -129,10 +123,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp index aebdb37d9..a0c416318 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -5,9 +5,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -20,11 +22,14 @@ using F32 = float; using ADataType = F32; using BDataType = F32; -using UnaryOp = ck::tensor_operation::element_wise::Scale; +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< ck::Tuple, // InDataTypeTuple ck::Tuple, // OutDataTypeTuple - UnaryOp, // UnaryOp + UnaryScaleSquare, // UnaryScaleSquare 4, // NumDim 256, // BlockSize 128, // M0PerBlock @@ -35,19 +40,6 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ck::Sequence<8>, // InScalarPerVectorSeq ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) -{ - for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) - for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) - for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) - for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) - { - auto a_val = A_nchw(n, c, h, w); - functor(B_nhwc(n, h, w, c), a_val); - } -} - int main() { bool do_verification = true; @@ -55,18 +47,6 @@ int main() std::vector nchw = {16, 128, 32, 64}; std::vector nhwc = {16, 32, 64, 128}; - Tensor a(nchw); - Tensor b(nhwc); - float scale = 2.f; - a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - 
DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; std::array ab_lengths; std::array a_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), @@ -80,9 +60,28 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); + std::array, 1> as = {Tensor(ab_lengths, a_strides)}; + Tensor& a = as[0]; + Tensor b(ab_lengths, b_strides); + float scale = 2.f; + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer( - ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); + auto argument = + broadcastPermute.MakeArgumentPointer(ab_lengths, + {a_strides}, + {b_strides}, + input, + output, + UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -112,10 +111,17 @@ int main() if(do_verification) { - b_device_buf.FromDevice(b.mData.data()); - Tensor host_b(nhwc); - host_elementwise4D(host_b, a, UnaryOp{scale}); + Tensor host_b(ab_lengths, b_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<1, ADataType, BDataType, UnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, host_b, UnaryScaleSquare{UnarySquare{}, UnaryScale{scale}}); + ref_invoker.Run(ref_argument); + b_device_buf.FromDevice(b.mData.data()); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); } diff --git a/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp new file mode 100644 index 000000000..050300eed --- /dev/null +++ b/example/44_elementwise_permute/elementwise_trinary_4D_fp16.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
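+//
+// Example: fuse B = alpha * A0 * A0 + beta * A1 * A1 + gamma * A2 * A2 over three fp16 NCHW
+// inputs into a single element-wise kernel by passing a TrinaryWithUnaryCombinedOp (two Adds
+// combined with a per-input scale-square unary op) to DeviceElementwiseImpl. The device
+// result is checked against the host-side ReferenceElementwise with a tolerance on the order
+// of fp16 epsilon.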
+ +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" + +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" + +using F16 = ck::half_t; +using F32 = float; + +using ADataType = F16; +using BDataType = F16; + +using UnaryScale = ck::tensor_operation::element_wise::Scale; +using UnarySquare = ck::tensor_operation::element_wise::UnarySquare; +using UnaryScaleSquare = + ck::tensor_operation::element_wise::UnaryCombinedOp; +using BinaryAdd = ck::tensor_operation::element_wise::Add; +// B = alpha * A0 * A0 + beta * A1 * A1 + gamma * A2 * A2 +using TrinaryAddUnaryScaleSquare = + ck::tensor_operation::element_wise::TrinaryWithUnaryCombinedOp; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + TrinaryAddUnaryScaleSquare, // ElementwiseOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8, 8, 8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq + +int main() +{ + bool do_verification = true; + bool time_kernel = true; + + std::vector nchw = {16, 128, 32, 64}; + std::array ab_lengths; + std::array ab_strides = {static_cast(nchw[1] * nchw[2] * nchw[3]), + static_cast(nchw[2] * nchw[3]), + static_cast(nchw[3]), + 1}; + + ck::ranges::copy(nchw, ab_lengths.begin()); + + std::array, 3> as = {Tensor(ab_lengths, ab_strides), + Tensor(ab_lengths, ab_strides), + Tensor(ab_lengths, ab_strides)}; + Tensor& a0 = as[0]; + Tensor& a1 = as[1]; + Tensor& a2 = as[2]; + Tensor b(ab_lengths, ab_strides); + float alpha = 3.f; + float beta = 2.f; + float gamma = 4.f; + a0.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a2.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + + DeviceMem a0_device_buf(sizeof(ADataType) * a0.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(ADataType) * a1.mDesc.GetElementSpaceSize()); + DeviceMem a2_device_buf(sizeof(ADataType) * a2.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0.mData.data()); + a1_device_buf.ToDevice(a1.mData.data()); + a2_device_buf.ToDevice(a2.mData.data()); + + std::array inputs = {a0_device_buf.GetDeviceBuffer(), + a1_device_buf.GetDeviceBuffer(), + a2_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + + auto broadcastPermute = DeviceElementwisePermuteInstance{}; + auto unary_scale_op_a0 = UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}}; + auto unary_scale_op_a1 = UnaryScaleSquare{UnarySquare{}, UnaryScale{beta}}; + auto unary_scale_op_a2 = UnaryScaleSquare{UnarySquare{}, UnaryScale{gamma}}; + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, + {ab_strides, ab_strides, ab_strides}, + {ab_strides}, + inputs, + output, + TrinaryAddUnaryScaleSquare{ + BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2}); + + 
if(!broadcastPermute.IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + }; + + std::cout << "A0 (nchw): " << a0.mDesc << std::endl; + std::cout << "A1 (nchw): " << a1.mDesc << std::endl; + std::cout << "A2 (nchw): " << a2.mDesc << std::endl; + std::cout << "B (nchw): " << b.mDesc << std::endl; + + auto broadcastPermute_invoker_ptr = broadcastPermute.MakeInvokerPointer(); + float ave_time = + broadcastPermute_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + std::size_t flop = std::size_t(5) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; + + std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + + sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + bool pass = true; + + if(do_verification) + { + Tensor host_b(ab_lengths, ab_strides); + using ReferenceElementwiseInstance = ck::tensor_operation::host:: + ReferenceElementwise<3, ADataType, BDataType, TrinaryAddUnaryScaleSquare>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument( + as, + host_b, + TrinaryAddUnaryScaleSquare{ + BinaryAdd{}, BinaryAdd{}, unary_scale_op_a0, unary_scale_op_a1, unary_scale_op_a2}); + ref_invoker.Run(ref_argument); + + const double threshold = std::pow(2, -10) * 2; + b_device_buf.FromDevice(b.mData.data()); + pass &= ck::utils::check_err( + b.mData, host_b.mData, "Error: Incorrect results b", threshold, threshold); + } + + return pass ? 
0 : 1; +} diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index ba2e0057d..f6e57aad0 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -92,6 +92,110 @@ struct Add }; }; +struct Max +{ + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + const Y x0_converted = type_convert(x0); + const Y x1_converted = type_convert(x1); + y = ck::math::max(x0_converted, x1_converted); + } +}; + +struct Min +{ + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + const Y x0_converted = type_convert(x0); + const Y x1_converted = type_convert(x1); + y = ck::math::min(x0_converted, x1_converted); + } +}; + +struct Multiply +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + y = x0 * type_convert(x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const float& x1) const + { + y = type_convert(x0 * x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(x0) * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x1); + y = x0 * x1_tmp; + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp * x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x0 * x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = x0 * x1; + }; +}; + struct ScaleAdd { __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {} diff --git a/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp new file mode 100644 index 000000000..6d1d6b57c --- /dev/null +++ b/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
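+//
+// Functors in this header compose existing unary and binary element-wise operations into a
+// single operation object that can be passed to DeviceElementwiseImpl as its element-wise op.
+// Usage sketch, mirroring the aliases built in example/44_elementwise_permute: a combined op
+// constructed as
+//   UnaryScaleSquare{UnarySquare{}, UnaryScale{alpha}}
+// evaluates y = alpha * x * x (square first, then scale), and
+//   BinaryAddUnaryScaleSquare{BinaryAdd{}, op_a0, op_a1}
+// evaluates y = alpha * x0 * x0 + beta * x1 * x1 when op_a0 and op_a1 carry alpha and beta.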
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +// y = UnaryOp0(UnaryOp1(...(x))) +template +struct UnaryCombinedOp +{ + __host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {} + + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // Execute first unary op to copy data to y + unary_ops_.At(Number<0>{})(y, x); + + static_for<1, Tuple::Size(), 1>{}([&](auto i) { unary_ops_.At(i)(y, y); }); + }; + + Tuple unary_ops_; +}; + +// y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1)) +template +struct BinaryWithUnaryCombinedOp +{ + __host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op, + UnaryOp0 unary_op0, + UnaryOp1 unary_op1) + : binary_op_(binary_op), unary_op0_(unary_op0), unary_op1_(unary_op1) + { + } + + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + Y unary_x0_tmp_result; + Y unary_x1_tmp_result; + unary_op0_(unary_x0_tmp_result, x0); + unary_op1_(unary_x1_tmp_result, x1); + binary_op_(y, unary_x0_tmp_result, unary_x1_tmp_result); + }; + + private: + BinaryOp binary_op_; + UnaryOp0 unary_op0_; + UnaryOp1 unary_op1_; +}; + +// y = BinaryOp0(BinaryOp1(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2)) +template +struct TrinaryWithUnaryCombinedOp +{ + __host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0, + BinaryOp0 binary_op1, + UnaryOp0 unary_op0, + UnaryOp1 unary_op1, + UnaryOp2 unary_op2) + : binary_op0_(binary_op0), + binary_op1_(binary_op1), + unary_op0_(unary_op0), + unary_op1_(unary_op1), + unary_op2_(unary_op2) + { + } + + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1, const X2& x2) const + { + + Y unary_x0_tmp_result; + Y unary_x1_tmp_result; + Y unary_x2_tmp_result; + unary_op0_(unary_x0_tmp_result, x0); + unary_op1_(unary_x1_tmp_result, x1); + unary_op2_(unary_x2_tmp_result, x2); + binary_op0_(unary_x0_tmp_result, unary_x0_tmp_result, unary_x1_tmp_result); + binary_op1_(y, unary_x0_tmp_result, unary_x2_tmp_result); + }; + + private: + BinaryOp0 binary_op0_{}; + BinaryOp1 binary_op1_{}; + UnaryOp0 unary_op0_{}; + UnaryOp1 unary_op1_{}; + UnaryOp2 unary_op2_{}; +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 9c64ad4df..1add81e69 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -12,10 +12,6 @@ namespace ck { namespace tensor_operation { namespace element_wise { -#if CK_WORKAROUND_SWDEV_383542 -extern "C" __device__ float __ocml_native_recip_f32(float); -#endif - struct PassThroughPack2 { template @@ -449,11 +445,7 @@ struct FastGelu const float u = x * (c1 * x * x + c2); const float emu = __expf(u); -#if !CK_WORKAROUND_SWDEV_383542 - y = x * __frcp_rn(1.f + emu); -#else - y = x * __ocml_native_recip_f32(1.f + emu); -#endif + y = x * ck::math::rcp(1.f + emu); } template <> @@ -559,6 +551,244 @@ struct TanH }; }; +struct ACos +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = 
ck::math::acos(x); + }; +}; + +struct Neg +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::neg(x); + }; +}; + +struct ATan +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::atan(x); + }; +}; + +struct Sin +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::sin(x); + }; +}; + +struct ASinH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::asinh(x); + }; +}; + +struct Cos +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::cos(x); + }; +}; + +struct ACosH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::acosh(x); + }; +}; + +struct Tan +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::tan(x); + }; +}; + +struct ATanH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::atanh(x); + }; +}; + +struct SinH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::sinh(x); + }; +}; + +struct Ceil +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::ceil(x); + }; +}; + +struct Exp +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::exp(x); + }; +}; + +struct CosH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::cosh(x); + }; +}; + +struct Floor +{ + template + 
__host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::floor(x); + }; +}; + +struct Log +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::log(x); + }; +}; + +struct ASin +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::asin(x); + }; +}; + +struct Rcp +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = ck::math::rcp(x); + }; +}; + struct Swish { Swish(float beta = 1.0f) : beta_(beta) {} diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp index 2a906a143..4d1a09b44 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp @@ -118,8 +118,16 @@ struct GridwiseElementwise __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock); const index_t m1_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock); - const auto thread_grid_offset = - make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + const auto input_thread_grid_offset = generate_tuple( + [&](auto) { + return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + }, + Number{}); + const auto output_thread_grid_offset = generate_tuple( + [&](auto) { + return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + }, + Number{}); using ThisThreadBlock = ThisThreadBlock; // If src and dst have same vector dim, then: @@ -157,9 +165,9 @@ struct GridwiseElementwise uniform_sequence_gen_t, uniform_sequence_gen_t, uniform_sequence_gen_t>{in_grid_desc_tuple, - thread_grid_offset, + input_thread_grid_offset, out_grid_desc_tuple, - thread_grid_offset, + output_thread_grid_offset, elementwise_op}; global_to_global_transfer.Run( in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0); diff --git a/include/ck/utility/math_v2.hpp b/include/ck/utility/math_v2.hpp index a07fde3da..2b921cdc7 100644 --- a/include/ck/utility/math_v2.hpp +++ b/include/ck/utility/math_v2.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
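+//
+// Besides the existing helpers, this header provides matching __host__ and __device__
+// wrappers for acos, asin, atan, asinh, acosh, atanh, sin, cos, tan, sinh, cosh, neg, ceil,
+// floor and rcp, so the unary element-wise functors can call the same ck::math:: entry
+// points from host reference code and from device kernels.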
#pragma once @@ -14,6 +14,10 @@ namespace ck { namespace math { +#if CK_WORKAROUND_SWDEV_383542 +extern "C" __device__ float __ocml_native_recip_f32(float); +#endif + // math functions for the host, some are implemented by calling C++ std functions static inline __host__ float abs(float x) { return std::abs(x); }; @@ -111,6 +115,276 @@ inline __host__ double tanh(double x) return std::tanh(x); }; +template +inline __host__ T acos(T x) +{ + return ck::type_convert(std::acosf(ck::type_convert(x))); +}; + +template <> +inline __host__ float acos(float x) +{ + return std::acosf(x); +}; + +template <> +inline __host__ double acos(double x) +{ + return std::acos(x); +}; + +template +inline __host__ T neg(T x) +{ + return ck::type_convert(-(ck::type_convert(x))); +}; + +template <> +inline __host__ float neg(float x) +{ + return -x; +}; + +template <> +inline __host__ double neg(double x) +{ + return -x; +}; + +template <> +inline __host__ int32_t neg(int32_t x) +{ + return -x; +}; + +template <> +inline __host__ int8_t neg(int8_t x) +{ + return -x; +}; + +template +inline __host__ T atan(T x) +{ + return ck::type_convert(std::atanf(ck::type_convert(x))); +}; + +template <> +inline __host__ float atan(float x) +{ + return std::atanf(x); +}; + +template <> +inline __host__ double atan(double x) +{ + return std::atan(x); +}; + +template +inline __host__ T sin(T x) +{ + return ck::type_convert(std::sinf(ck::type_convert(x))); +}; + +template <> +inline __host__ float sin(float x) +{ + return std::sinf(x); +}; + +template <> +inline __host__ double sin(double x) +{ + return std::sin(x); +}; + +template +inline __host__ T asin(T x) +{ + return ck::type_convert(std::asinf(ck::type_convert(x))); +}; + +template <> +inline __host__ float asin(float x) +{ + return std::asinf(x); +}; + +template <> +inline __host__ double asin(double x) +{ + return std::asin(x); +}; + +template +inline __host__ T asinh(T x) +{ + return ck::type_convert(std::asinhf(ck::type_convert(x))); +}; + +template <> +inline __host__ float asinh(float x) +{ + return std::asinhf(x); +}; + +template <> +inline __host__ double asinh(double x) +{ + return std::asinh(x); +}; + +template +inline __host__ T cos(T x) +{ + return ck::type_convert(std::cosf(ck::type_convert(x))); +}; + +template <> +inline __host__ float cos(float x) +{ + return std::cosf(x); +}; + +template <> +inline __host__ double cos(double x) +{ + return std::cos(x); +}; + +template +inline __host__ T acosh(T x) +{ + return ck::type_convert(std::acoshf(ck::type_convert(x))); +}; + +template <> +inline __host__ float acosh(float x) +{ + return std::acoshf(x); +}; + +template <> +inline __host__ double acosh(double x) +{ + return std::acosh(x); +}; + +template +inline __host__ T tan(T x) +{ + return ck::type_convert(std::tanf(ck::type_convert(x))); +}; + +template <> +inline __host__ float tan(float x) +{ + return std::tanf(x); +}; + +template <> +inline __host__ double tan(double x) +{ + return std::tan(x); +}; + +template +inline __host__ T atanh(T x) +{ + return ck::type_convert(std::atanhf(ck::type_convert(x))); +}; + +template <> +inline __host__ float atanh(float x) +{ + return std::atanhf(x); +}; + +template <> +inline __host__ double atanh(double x) +{ + return std::atanh(x); +}; + +template +inline __host__ T sinh(T x) +{ + return ck::type_convert(std::sinhf(ck::type_convert(x))); +}; + +template <> +inline __host__ float sinh(float x) +{ + return std::sinhf(x); +}; + +template <> +inline __host__ double sinh(double x) +{ + return std::sinh(x); +}; + +template 
+inline __host__ T ceil(T x) +{ + return ck::type_convert(std::ceilf(ck::type_convert(x))); +}; + +template <> +inline __host__ float ceil(float x) +{ + return std::ceilf(x); +}; + +template <> +inline __host__ double ceil(double x) +{ + return std::ceil(x); +}; + +template +inline __host__ T cosh(T x) +{ + return ck::type_convert(std::coshf(ck::type_convert(x))); +}; + +template <> +inline __host__ float cosh(float x) +{ + return std::coshf(x); +}; + +template <> +inline __host__ double cosh(double x) +{ + return std::cosh(x); +}; + +template +inline __host__ T floor(T x) +{ + return ck::type_convert(std::floorf(ck::type_convert(x))); +}; + +template <> +inline __host__ float floor(float x) +{ + return std::floorf(x); +}; + +template <> +inline __host__ double floor(double x) +{ + return std::floor(x); +}; + +template +inline __host__ T rcp(T x) +{ + return ck::type_convert(1.f / ck::type_convert(x)); +}; + template inline __host__ T exp(T x) { @@ -282,6 +556,286 @@ inline __device__ double tanh(double x) return ::tanh(x); }; +template +inline __device__ T acos(T x) +{ + return ck::type_convert(::acosf(ck::type_convert(x))); +}; + +template <> +inline __device__ float acos(float x) +{ + return ::acosf(x); +}; + +template <> +inline __device__ double acos(double x) +{ + return ::acos(x); +}; + +template +inline __device__ T neg(T x) +{ + return ck::type_convert(-(ck::type_convert(x))); +}; + +template <> +inline __device__ float neg(float x) +{ + return -x; +}; + +template <> +inline __device__ double neg(double x) +{ + return -x; +}; + +template <> +inline __device__ int32_t neg(int32_t x) +{ + return -x; +}; + +template <> +inline __device__ int8_t neg(int8_t x) +{ + return -x; +}; + +template <> +inline __device__ half_t neg(half_t x) +{ + return __hneg(x); +}; + +template +inline __device__ T atan(T x) +{ + return ck::type_convert(::atanf(ck::type_convert(x))); +}; + +template <> +inline __device__ float atan(float x) +{ + return ::atanf(x); +}; + +template <> +inline __device__ double atan(double x) +{ + return ::atan(x); +}; + +template +inline __device__ T sin(T x) +{ + return ck::type_convert(::sinf(ck::type_convert(x))); +}; + +template <> +inline __device__ float sin(float x) +{ + return ::sinf(x); +}; + +template <> +inline __device__ double sin(double x) +{ + return ::sin(x); +}; + +template <> +inline __device__ half_t sin(half_t x) +{ + return ::hsin(x); +}; + +template +inline __device__ T asin(T x) +{ + return ck::type_convert(::asinf(ck::type_convert(x))); +}; + +template <> +inline __device__ float asin(float x) +{ + return ::asinf(x); +}; + +template <> +inline __device__ double asin(double x) +{ + return ::asin(x); +}; + +template +inline __device__ T asinh(T x) +{ + return ck::type_convert(::asinhf(ck::type_convert(x))); +}; + +template <> +inline __device__ float asinh(float x) +{ + return ::asinhf(x); +}; + +template <> +inline __device__ double asinh(double x) +{ + return ::asinh(x); +}; + +template +inline __device__ T acosh(T x) +{ + return ck::type_convert(::acoshf(ck::type_convert(x))); +}; + +template <> +inline __device__ float acosh(float x) +{ + return ::acoshf(x); +}; + +template <> +inline __device__ double acosh(double x) +{ + return ::acosh(x); +}; + +template +inline __device__ T tan(T x) +{ + return ck::type_convert(::tanf(ck::type_convert(x))); +}; + +template <> +inline __device__ float tan(float x) +{ + return ::tanf(x); +}; + +template <> +inline __device__ double tan(double x) +{ + return ::tan(x); +}; + +template +inline __device__ T atanh(T x) 
+{ + return ck::type_convert(::atanhf(ck::type_convert(x))); +}; + +template <> +inline __device__ float atanh(float x) +{ + return ::atanhf(x); +}; + +template <> +inline __device__ double atanh(double x) +{ + return ::atanh(x); +}; + +template +inline __device__ T sinh(T x) +{ + return ck::type_convert(::sinhf(ck::type_convert(x))); +}; + +template <> +inline __device__ float sinh(float x) +{ + return ::sinhf(x); +}; + +template <> +inline __device__ double sinh(double x) +{ + return ::sinh(x); +}; + +template +inline __device__ T ceil(T x) +{ + return ck::type_convert(::ceilf(ck::type_convert(x))); +}; + +template <> +inline __device__ float ceil(float x) +{ + return ::ceilf(x); +}; + +template <> +inline __device__ double ceil(double x) +{ + return ::ceil(x); +}; + +template <> +inline __device__ half_t ceil(half_t x) +{ + return ::hceil(x); +}; + +template +inline __device__ T cosh(T x) +{ + return ck::type_convert(::coshf(ck::type_convert(x))); +}; + +template <> +inline __device__ float cosh(float x) +{ + return ::coshf(x); +}; + +template <> +inline __device__ double cosh(double x) +{ + return ::cosh(x); +}; + +template +inline __device__ T floor(T x) +{ + return ck::type_convert(::floorf(ck::type_convert(x))); +}; + +template <> +inline __device__ float floor(float x) +{ + return ::floorf(x); +}; + +template <> +inline __device__ double floor(double x) +{ + return ::floor(x); +}; + +template <> +inline __device__ half_t floor(half_t x) +{ + return ::hfloor(x); +}; + +template +inline __device__ T rcp(T x) +{ +#if !CK_WORKAROUND_SWDEV_383542 + return __frcp_rn(x); +#else + return __ocml_native_recip_f32(x); +#endif +}; + template inline __device__ T exp(T x) { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp new file mode 100644 index 000000000..470641fff --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
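// A minimal sketch, assuming the usual ck::type_convert round trip, of how the
// generic math wrappers introduced above are meant to read once their template
// headers are written out (acos shown as a representative; the other wrappers
// follow the same shape):
//
//   template <typename T>
//   inline __host__ T acos(T x)
//   {
//       return ck::type_convert<T>(std::acosf(ck::type_convert<float>(x)));
//   }
//
// Likewise, the device-side reciprocal added above chooses between the HIP
// intrinsic and the OCML builtin declared near the top of the math header:
//
//   template <typename T>
//   inline __device__ T rcp(T x)
//   {
//   #if !CK_WORKAROUND_SWDEV_383542
//       return __frcp_rn(x);               // round-to-nearest reciprocal intrinsic
//   #else
//       return __ocml_native_recip_f32(x); // workaround path for SWDEV-383542
//   #endif
//   }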
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceElementwise : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const std::array, NumATensors>& a_tensors, + Tensor& b_tensor, + ElementOp element_op) + : a_tensors_{a_tensors}, b_tensor_{b_tensor}, element_op_{element_op} + { + } + + const std::array, NumATensors>& a_tensors_; + Tensor& b_tensor_; + ElementOp element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceElementwise::Argument; + + float Run(const Argument& arg) + { + if constexpr(NumATensors == 1) + { + arg.b_tensor_.ForEach([&](auto& self, auto idx) { + arg.element_op_(self(idx), arg.a_tensors_[0](idx)); + }); + } + else if constexpr(NumATensors == 2) + { + arg.b_tensor_.ForEach([&](auto& self, auto idx) { + arg.element_op_(self(idx), arg.a_tensors_[0](idx), arg.a_tensors_[1](idx)); + }); + } + else if constexpr(NumATensors == 3) + { + arg.b_tensor_.ForEach([&](auto& self, auto idx) { + arg.element_op_(self(idx), + arg.a_tensors_[0](idx), + arg.a_tensors_[1](idx), + arg.a_tensors_[2](idx)); + }); + } + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const std::array, NumATensors>& a_tensors, + Tensor& b_tensor, + ElementOp element_op) + { + return Argument{a_tensors, b_tensor, element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceElementwise" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_permute_scale_impl.hpp b/profiler/include/profiler/profile_permute_scale_impl.hpp index c69e36142..186a24501 100644 --- a/profiler/include/profiler/profile_permute_scale_impl.hpp +++ b/profiler/include/profiler/profile_permute_scale_impl.hpp @@ -14,6 +14,8 @@ #include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp" + #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -21,14 +23,6 @@ #include "ck/library/utility/literals.hpp" namespace ck { -template -void reference_permute_scale(HostTensorB& b_tensor, - const HostTensorA& a_tensor, - ElementOp tensor_op) -{ - b_tensor.ForEach([&](auto& self, auto idx) { tensor_op(self(idx), a_tensor(idx)); }); -} - namespace profiler { template @@ -46,7 +40,8 @@ bool profile_permute_scale_impl(int do_verification, using ElementOp = ck::tensor_operation::element_wise::Scale; float scale = 2.f; - Tensor a(lengths_vector, input_strides_vector); + std::array, 1> as = 
{Tensor(lengths_vector, input_strides_vector)}; + Tensor& a = as[0]; Tensor b(lengths_vector, output_strides_vector); Tensor host_b(lengths_vector, output_strides_vector); @@ -83,7 +78,14 @@ bool profile_permute_scale_impl(int do_verification, if(do_verification) { - reference_permute_scale(host_b, a, ElementOp{scale}); + using ReferenceElementwiseInstance = + ck::tensor_operation::host::ReferenceElementwise<1, ADataType, BDataType, ElementOp>; + auto ref_elementwise = ReferenceElementwiseInstance{}; + auto ref_invoker = ref_elementwise.MakeInvoker(); + + auto ref_argument = ref_elementwise.MakeArgument(as, host_b, ElementOp{scale}); + + ref_invoker.Run(ref_argument); } auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; -- GitLab From a61e73bc56966a138ab1b5dadf27983800788431 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Wed, 3 Apr 2024 09:08:08 -0500 Subject: [PATCH 17/63] Add instances for conv_scale with fp8@bf8->fp8 (#1220) * Update device op api to support BComputeType * Add example * Add instances * Add profiler mode * Add client example * Update copyright year * Add BComputeType check * Fix compute types --- client_example/16_convnd_fwd/CMakeLists.txt | 5 + client_example/16_convnd_fwd/common.hpp | 8 +- .../16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp | 50 ++++++++ example/09_convnd_fwd/CMakeLists.txt | 1 + .../09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp | 83 ++++++++++++ .../device_grouped_conv_fwd_multiple_abd.hpp | 10 +- ...ped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 15 ++- ...ouped_conv_fwd_multiple_d_xdl_cshuffle.hpp | 10 +- ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 39 +++--- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 7 +- .../device_grouped_conv_fwd_xdl_instance.hpp | 36 ++++++ .../gpu/grouped_convolution_forward.hpp | 118 ++++++++++++------ .../gpu/grouped_convolution_forward_xdl.inc | 18 +++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 5 + ..._ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp | 54 ++++++++ .../profile_grouped_conv_fwd_impl.hpp | 10 +- profiler/src/profile_grouped_conv_fwd.cpp | 75 ++++++----- 17 files changed, 441 insertions(+), 103 deletions(-) create mode 100644 client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp create mode 100644 example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index e034c468d..808693b63 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -17,6 +17,11 @@ if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) endif() +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp) + target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) +endif() + if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/16_convnd_fwd/common.hpp b/client_example/16_convnd_fwd/common.hpp index a5b7c5b42..ee408c744 100644 --- a/client_example/16_convnd_fwd/common.hpp +++ 
b/client_example/16_convnd_fwd/common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -95,7 +95,8 @@ template + typename AComputeType = InDataType, + typename BComputeType = AComputeType> bool run_grouped_conv_fwd(std::array in_lengths, std::array wei_lengths, std::array out_lengths) @@ -186,7 +187,8 @@ bool run_grouped_conv_fwd(std::array; + AComputeType, + BComputeType>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp new file mode 100644 index 000000000..8508dc9c5 --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using OutDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +using AComputeType = ck::f8_t; +using BComputeType = ck::bf8_t; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 61e9a43c3..afbe74121 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -5,6 +5,7 @@ add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp) add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp new file mode 100644 index 000000000..53a12377c --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp8_bf8.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
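// A minimal sketch, assuming the template-parameter order used by the client
// helper run_grouped_conv_fwd in common.hpp (spatial dims, tensor data types,
// layouts, then the two compute types), of how the fp8 x bf8 client example
// above selects a separate compute type per input tensor:
//
//   return run_grouped_conv_fwd<NumDimSpatial,
//                               InDataType, WeiDataType, OutDataType,
//                               InLayout, WeiLayout, OutLayout,
//                               AComputeType, BComputeType>(
//              {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
//              ? EXIT_SUCCESS
//              : EXIT_FAILURE;
//
// With AComputeType = ck::f8_t and BComputeType = ck::bf8_t this resolves to the
// fp8@bf8->fp8 device instances registered by this patch.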
+ +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; +using OutDataType = ck::f8_t; +using AComputeType = ck::f8_t; +using BComputeType = ck::bf8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeType, + BComputeType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp index fa3dcfdf2..31e8d639a 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -40,7 +40,8 @@ using is_tuple = decltype(std::declval().IsTuple()); * \tparam AElementwiseOperation A elementwise operation. * \tparam BElementwiseOperation B elementwise operation. * \tparam CDEElementwiseOperation CDE elementwise operation. - * \tparam ComputeType Compute data type (default: ADataType, first if tuple passed). + * \tparam AComputeType Compute data type for A tensor (default: ADataType, first if tuple passed). + * \tparam BComputeType Compute data type for B tensor (default: AComputeType). 
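 *
 * Example (illustrative): the fp8@bf8->fp8 configuration added alongside this
 * interface keeps f8_t activations and bf8_t weights as the stored ADataType and
 * BDataType while passing AComputeType = f8_t and BComputeType = bf8_t.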
*/ template ::value, Number<0>, - ADataType>())> // ComputeType is InputType by default (first + ADataType>()), // AComputeType is InputType by default (first // in tuple for MultiAB), unpack if tuple was // passed + typename BComputeType = AComputeType> struct DeviceGroupedConvFwdMultipleABD : public BaseOperator { static constexpr bool isMultiA = is_detected::value; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 5ff42f98f..f53ec8a4e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -254,13 +254,14 @@ template ::value, Number<0>, ADataType>()), // ComputeType is InputType by default (first // in tuple for MultiAB), unpack if tuple was // passed - LoopScheduler LoopSched = make_default_loop_scheduler()> + typename BComputeDataType = AComputeDataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle : public DeviceGroupedConvFwdMultipleABD + AComputeDataType, + BComputeDataType> { using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; @@ -386,7 +388,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmBDataType = std::conditional_t, BDataType>; #define GridwiseGemmTemplateParameters \ - GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ @@ -399,7 +401,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ - CDEBlockTransferScalarPerVector_NPerBlock, LoopSched + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ + BComputeDataType // Use appropriate gridwise gemm using GridwiseGemm = std::conditional_t::value, Number<0>, ADataType>()), // ComputeType is InputType by default (first // in tuple for MultiAB), unpack if tuple was // passed - LoopScheduler LoopSched = make_default_loop_scheduler()> + typename BComputeDataType = AComputeDataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, @@ -128,7 +129,8 @@ using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipl CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, - ComputeDataType, + AComputeDataType, + BComputeDataType, LoopSched>; } // namespace device diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 4b7cc5679..0f98f9e63 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -30,7 +30,7 @@ namespace ck { // D0, D1, ... and E have the same layout template + PipelineVersion PipelineVer = PipelineVersion::v1, + typename BComputeDataType_ = AComputeDataType_> struct GridwiseGemmMultipleABD_xdl_cshuffle { static constexpr index_t NumATensor = AsDataType::Size(); @@ -101,10 +102,13 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle decltype(GridwiseGemmPipeline_Selector())>; #if CK_WORKAROUND_DENORM_FIX - using ComputeDataType = - conditional_t, ck::bhalf_t, ComputeDataType_>; + using AComputeDataType = + conditional_t, ck::bhalf_t, AComputeDataType_>; + using BComputeDataType = + conditional_t, ck::bhalf_t, BComputeDataType_>; #else - using ComputeDataType = ComputeDataType_; + using AComputeDataType = AComputeDataType_; + using BComputeDataType = BComputeDataType_; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -195,8 +199,8 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle constexpr auto c_block_size = c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); - return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * - sizeof(ComputeDataType), + return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) + + b_block_space_size_aligned * sizeof(BComputeDataType), c_block_size * sizeof(CShuffleDataType)); } @@ -597,7 +601,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< ThisThreadBlock, AsDataType, - Tuple, + Tuple, decltype(as_grid_desc_ak0_m_ak1), decltype(tie(a_block_desc_ak0_m_ak1)), AElementwiseOperation, @@ -628,7 +632,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< ThisThreadBlock, BsDataType, - Tuple, + Tuple, decltype(bs_grid_desc_bk0_n_bk1), decltype(tie(b_block_desc_bk0_n_bk1)), BElementwiseOperation, @@ -656,14 +660,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - constexpr index_t KPack = - math::max(math::lcm(AK1, BK1), - MfmaSelector::selected_mfma.k_per_blk); + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), + MfmaSelector::selected_mfma + .k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeDataType, // ComputeDataType for A - ComputeDataType, // ComputeDataType for B + AComputeDataType, + BComputeDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), @@ -681,10 +686,10 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - static_cast(p_shared) + a_block_space_size_aligned, + static_cast(p_shared) + a_block_space_size_aligned, 
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index c0a3d29f8..6ddc3aca1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -73,7 +73,7 @@ template + typename BComputeDataType_ = AComputeDataType_> struct GridwiseGemmMultipleD_xdl_cshuffle { static constexpr index_t NumDTensor = DsDataType::Size(); @@ -103,8 +103,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle #if CK_WORKAROUND_DENORM_FIX using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; + using BComputeDataType = + conditional_t, ck::bhalf_t, BComputeDataType_>; #else using AComputeDataType = AComputeDataType_; + using BComputeDataType = BComputeDataType_; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index 0f845ca1e..40878e4f0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -290,6 +290,42 @@ using device_grouped_conv_fwd_xdl_bf8_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_f8_bf8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|AComputeType|BComputeType| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | 
| | | | | +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8, BF8> +#endif + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 24a5f9a5c..e61ec2828 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -34,7 +34,8 @@ template + typename AComputeType, + typename BComputeType> struct DeviceOperationInstanceFactory> + AComputeType, + BComputeType>> { using DeviceOp = DeviceGroupedConvFwdMultipleABD; + 
AComputeType, + BComputeType>; static auto GetInstances() { @@ -75,14 +78,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); } @@ -94,14 +99,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); } @@ -115,14 +122,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs); } @@ -130,14 +139,17 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(op_ptrs); } @@ -149,14 +161,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); } @@ -164,7 +178,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs); } @@ -176,14 +192,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs); } @@ -191,7 +209,9 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); } @@ -203,14 +223,16 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs); } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { 
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); } @@ -218,14 +240,17 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs); } @@ -237,7 +262,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); } @@ -245,28 +271,40 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances( op_ptrs); } if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances(op_ptrs); } #endif #ifdef CK_ENABLE_BF8 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances(op_ptrs); } #endif +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(op_ptrs); + } +#endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); } @@ -274,14 +312,17 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && + is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs); } #endif #ifdef CK_ENABLE_INT8 if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs); } @@ -295,7 +336,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(op_ptrs); @@ -305,7 +347,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(op_ptrs); @@ -320,7 +363,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(op_ptrs); add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(op_ptrs); @@ -335,7 +379,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances( @@ -347,7 +392,8 @@ struct 
DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(op_ptrs); @@ -363,7 +409,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances( @@ -375,7 +422,8 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v && + is_same_v) { add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(op_ptrs); add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index 942674ef9..691414ebc 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -351,6 +351,24 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( BF8>>>& instances); #endif +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 972fb5403..50a6ec9a4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -41,4 +41,9 @@ if(DTYPES MATCHES "bf8" OR NOT DEFINED DTYPES) xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp) endif() +if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + list(APPEND GROUPED_CONV3D_FWD + xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp) +endif() + add_instance_library(device_grouped_conv3d_fwd_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp new file mode 100644 index 000000000..d42104bf6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
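// A minimal sketch, with template arguments written out and assuming the
// factory's usual type-matching pattern, of the compute-type-aware branch that
// the hunks above add for the new fp8 x bf8 3D forward-convolution instances:
//
//   #if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8))
//       if constexpr(is_same_v<InDataType, ck::f8_t> && is_same_v<WeiDataType, ck::bf8_t> &&
//                    is_same_v<OutDataType, ck::f8_t> && is_same_v<AComputeType, ck::f8_t> &&
//                    is_same_v<BComputeType, ck::bf8_t>)
//       {
//           add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances(op_ptrs);
//       }
//   #endif
//
// The instance list itself, defined in the new source file that follows, is
// registered for the Default, 1x1P0 and 1x1S1P0 forward-convolution
// specializations.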
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index f629809da..d91387330 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -31,7 +31,9 @@ template + typename OutDataType, + typename AComputeType = InDataType, + typename BComputeType = AComputeType> bool profile_grouped_conv_fwd_impl(int do_verification, int init_method, bool do_log, @@ -209,7 +211,9 @@ bool profile_grouped_conv_fwd_impl(int do_verification, OutDataType, InElementOp, WeiElementOp, - OutElementOp>; + OutElementOp, + AComputeType, + BComputeType>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 1f7273372..a847999b5 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -25,6 +25,7 @@ enum struct ConvDataType INT8_INT8_INT8, // 3 F8_F8_F8, // 4 BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 }; #define OP_NAME "grouped_conv_fwd" @@ -40,7 +41,8 @@ static void print_helper_msg() << " 2: Input bf16, Weight bf16, Output bf16\n" << " 3: Input int8, Weight int8, Output int8\n" << " 4: Input fp8, Weight fp8, Output fp8\n" - << " 5: Input bf8, Weight bf8, Output fp8)\n" + << " 5: Input bf8, Weight bf8, Output fp8\n" + << " 6: Input fp8, Weight bf8, Output fp8)\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" << "arg4: verification (0: no, 1: yes)\n" @@ -118,7 +120,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) auto out_layout, auto in_type, auto wei_type, - auto out_type) { + auto out_type, + auto a_compute_type, + auto b_compute_type) { constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; using InLayout = decltype(in_layout); @@ -129,13 +133,18 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using WeiDataType = decltype(wei_type); using OutDataType = decltype(out_type); + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + bool pass = 
ck::profiler::profile_grouped_conv_fwd_impl( + OutDataType, + AComputeType, + BComputeType>( do_verification, init_method, do_log, time_kernel, params); return pass ? 0 : 1; @@ -146,57 +155,59 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, BF16{}, BF16{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}); + return profile(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } } else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}); + return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } } else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}); + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}); + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}); + return profile( + I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}); + return profile( + I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } } // NHWGC_GKYXC_NHWGK @@ -204,65 +215,71 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}); + return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{}); + return profile(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{}); + 
return profile(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{}); + return profile(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } } else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}); + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}); + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}); + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}); + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } } else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) { if(data_type == ConvDataType::F32_F32_F32) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); } else if(data_type == ConvDataType::F16_F16_F16) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); } else if(data_type == ConvDataType::BF16_BF16_BF16) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}); + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } else if(data_type == ConvDataType::INT8_INT8_INT8) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{}); + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } else if(data_type == ConvDataType::F8_F8_F8) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, F8{}, F8{}); } else if(data_type == ConvDataType::BF8_BF8_F8) { - return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, F8{}); + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, F8{}, BF8{}, BF8{}); + } + else if(data_type == ConvDataType::F8_BF8_F8) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, F8{}, BF8{}); } } -- GitLab From c701071666ce5656c8bd4331979f56fcc497fda6 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 4 Apr 2024 11:01:33 +0200 Subject: [PATCH 18/63] Add Grouped Gemm Multiple D SplitK TwoStage (#1212) * Support A/B/C elementwise ops. * First part of GGEMM multiD splitk two stage. * WIP - changes for debuggin. 
* tmp save * working version * added bf16@int8 version * fixes * add reviewers sugestions * pre-commited missing files * switched to ifs from elseifs --------- Co-authored-by: Adam Osewski --- ...rouped_gemm_multiple_d_splitk_xdl_fp16.cpp | 394 +++++++ .../device_grouped_gemm_multiple_d_splitk.hpp | 136 +++ .../device/impl/device_elementwise_impl.hpp | 16 +- ...ltiple_d_splitk_xdl_cshuffle_two_stage.hpp | 987 ++++++++++++++++++ ...evice_grouped_gemm_xdl_splitk_cshuffle.hpp | 27 +- .../cpu/reference_gemm_multiple_d.hpp | 175 ++++ .../gpu/grouped_gemm.hpp | 50 +- .../gpu/grouped_gemm/CMakeLists.txt | 2 + ...o_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp | 99 ++ ...wo_stage_f16_f16_f16_mk_kn_mn_instance.cpp | 96 ++ .../profile_grouped_gemm_two_stage_impl.hpp | 366 +++++++ profiler/src/CMakeLists.txt | 1 + .../src/profile_grouped_gemm_two_stage.cpp | 157 +++ 13 files changed, 2490 insertions(+), 16 deletions(-) create mode 100644 example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp create mode 100644 profiler/src/profile_grouped_gemm_two_stage.cpp diff --git a/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp new file mode 100644 index 000000000..ecff7b471 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_multiple_d_splitk_xdl_fp16.cpp @@ -0,0 +1,394 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
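// A rough sketch, inferred from the operation name and the workspace usage in
// the example below rather than from a specification, of what the two-stage
// split-K grouped GEMM computes for each group g when K is split into k_batch
// chunks:
//
//   // stage 1: every k-batch kb writes a partial product into the workspace
//   //          W[g][kb] = A[g](:, K-slice kb) * B[g](K-slice kb, :)
//   // stage 2: the partials are reduced and the CDE elementwise op is applied
//   //          E[g] = cde_op( sum over kb of W[g][kb], D0[g], D1[g] )
//
// In the fp16 example that follows, cde_op is AddAdd, so each group computes
// E = (A * B) + D0 + D1 with the K accumulation split across k_batch = 128
// chunks.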
+ +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include +#include + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm_multiple_d.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddAdd = ck::tensor_operation::element_wise::AddAdd; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DDataType = F16; +using DsDataType = ck::Tuple; +using EDataType = F32; + +using ALayout = Row; +using BLayout = Col; +using DLayout = Row; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = AddAdd; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr int NumDMatrices = 2; + +using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + 
std::vector stride_Bs; + std::vector> stride_Ds; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + int k_batch = 128; + bool time_kernel = true; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_Cs; + std::vector p_As; + std::vector p_Bs; + std::vector> p_Ds = {}; + + gemm_descs.reserve(group_count); + p_As.reserve(group_count); + p_Bs.reserve(group_count); + p_Ds.reserve(group_count); + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector, NumDMatrices>> d_tensors; + std::vector> c_host_tensors; + std::vector> c_device_result_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + d_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_result_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + std::vector> d_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + d_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + + auto d0_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + auto d1_tensor = Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], DLayout{})); + + std::array, NumDMatrices> d_tens = {d0_tensor, d1_tensor}; + d_tensors.push_back(d_tens); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_result_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc + << " c_m_n: " << c_device_result_tensors[i].mDesc << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].GetElementSize() + + sizeof(BDataType) * b_tensors[i].GetElementSize() + + sizeof(DDataType) * d_tensors[i][0].GetElementSize() * NumDMatrices + + sizeof(EDataType) * c_device_result_tensors[i].GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + 
b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors[i][j].GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + } + } + } + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(a_tensors[i].GetElementSpaceSize() * sizeof(ADataType))); + + b_tensors_device.emplace_back( + std::make_unique(b_tensors[i].GetElementSpaceSize() * sizeof(BDataType))); + + c_tensors_device.emplace_back(std::make_unique( + c_device_result_tensors[i].GetElementSpaceSize() * sizeof(EDataType))); + + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors_device[i].emplace_back(std::make_unique( + d_tensors[i][j].GetElementSpaceSize() * sizeof(DDataType))); + } + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); + for(int j = 0; j < NumDMatrices; ++j) + { + d_tensors_device[i][j]->ToDevice(d_tensors[i][j].mData.data()); + } + c_tensors_device[i]->SetZero(); + p_As.push_back(a_tensors_device[i]->GetDeviceBuffer()); + p_Bs.push_back(b_tensors_device[i]->GetDeviceBuffer()); + p_Ds.push_back( + {d_tensors_device[i][0]->GetDeviceBuffer(), d_tensors_device[i][1]->GetDeviceBuffer()}); + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + gemm_descs.push_back({problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + problem_size.stride_Cs[i], + problem_size.stride_Ds[i]}); + } + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, cde_element_op); + gemm.SetKBatchSize(argument, config.k_batch); + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer()); + + invoker.Run(argument, StreamConfig{nullptr, false, 1}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmMultipleD; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + auto karg = argument.gemm_kernel_args_[i].karg_; + auto dev_res_tensor = + Tensor(f_host_tensor_descriptor(karg.M, karg.N, karg.StrideC, ELayout{})); + + c_tensors_device[i]->FromDevice(c_device_result_tensors[i].mData.data(), + c_device_result_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + d_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + cde_element_op); + + ref_invoker.Run(ref_argument); + pass &= ck::utils::check_err(c_device_result_tensors[i], c_host_tensors[i]); + } + + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl; + } + + return pass; +} + +std::vector argToIntArray(char* input) +{ + std::vector out; + + std::istringstream in(input); + + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + + return out; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + if(argc < 11) + { + std::vector Ms{64, 127, 255, 129, 260, 190, 77}; + problem_size.group_count = Ms.size(); + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(Ms[i]); + problem_size.Ns.push_back(252); + problem_size.Ks.push_back(4608); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + + problem_size.stride_Ds.push_back({}); + for(int j = 0; j < NumDMatrices; ++j) + { + problem_size.stride_Ds[i].push_back(problem_size.Ns[i]); + } + } + + std::cout + << "Usage:\n" + << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: time kernel (0=n0, 1=yes)\n" + << "arg4 to 9: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n" + << "arg10: k_batch (> 0)\n" + << "... setting default values." 
<< std::endl; + } + else + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[10]); + + problem_size.Ms = argToIntArray(argv[4]); + problem_size.Ns = argToIntArray(argv[5]); + problem_size.Ks = argToIntArray(argv[6]); + + problem_size.stride_As = argToIntArray(argv[7]); + problem_size.stride_Bs = argToIntArray(argv[8]); + problem_size.stride_Cs = argToIntArray(argv[9]); + + for(int j = 0; j < NumDMatrices; ++j) + { + problem_size.stride_Ds.push_back(problem_size.stride_Cs); + } + + problem_size.group_count = problem_size.Ms.size(); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp new file mode 100644 index 000000000..d91eac073 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// +/// @brief Structure representing single GEMM problem arguments. +/// +/// The pointer to the vector of those structures is passed to the GroupedGEMM entry +/// point kernel. +/// +/// @tparam NumDTensor The number of D input tensors. +/// +template +struct GroupedGemmMultipleDKernelArguments +{ + __host__ __device__ + GroupedGemmMultipleDKernelArguments(const void* p_a_grid_, + const void* p_b_grid_, + std::array p_ds_grid_, + void* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_) + : p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{p_ds_grid_}, + p_e_grid{p_e_grid_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideE{StrideE_} + { + } + + const void* p_a_grid; + const void* p_b_grid; + std::array p_ds_grid; + void* p_e_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideE; + + void Print() const + { + std::stringstream str; + for(auto sd : StrideDs) + str << sd << ","; + + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SE:" << StrideE << ", " + << "SDs: {" << str.str() << "}" + << "}" << std::endl; + } +}; + +template +struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm +{ + //---------------------------------------------------------------------------------------------- + /// @brief Sets the k batch size. + /// + /// @param p_arg Pointer to the Argument we're going to change. + /// @param[in] kbatch The kbatch value. + /// + virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0; + + //---------------------------------------------------------------------------------------------- + /// @brief Sets the device kernel arguments pointer. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel + /// arguments. 
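+    ///
+    /// @note A minimal host-side sketch of the intended call order, reusing the
+    ///       DeviceMem helper already used by the example in this patch
+    ///       (`op` and `arg` stand for a concrete device operation and its argument):
+    /// @code
+    ///   DeviceMem kargs_dev(op.GetDeviceKernelArgSize(&arg));
+    ///   op.SetDeviceKernelArgs(&arg, kargs_dev.GetDeviceBuffer());
+    /// @endcode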
+ /// + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0; + + //---------------------------------------------------------------------------------------------- + /// @brief Gets the device kernel argument size. + /// + /// @param[in] p_arg The pointer to the Device op Argument. + /// + /// @return The device kernel argument size. + /// + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp index 37867f1ea..1a44c3ed9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,10 +22,12 @@ namespace device { template + index_t NumDim, // The max dim of input tensors + // the tensors descs have to be aligned, such that + // the innermost dim is the contiguous one. + index_t MPerThread, // How many elements per thread to read + typename InScalarPerVectorSeq, // Scalar per vec for each Input + typename OutScalarPerVectorSeq> // Scalar per vec for each Output struct DeviceElementwiseImpl : public DeviceElementwise { @@ -242,13 +244,13 @@ struct DeviceElementwiseImpl static_for<0, NumInput, 1>{}([&](auto I) { if(!IsScalarPerVectorValid( arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I))) - valid = false; + valid = valid && false; }); static_for<0, NumOutput, 1>{}([&](auto I) { if(!IsScalarPerVectorValid( arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I))) - valid = false; + valid = valid && false; }); return valid; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp new file mode 100644 index 000000000..2d60c027b --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -0,0 +1,987 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
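+//
+// Two-stage grouped GEMM with split-K and multiple D tensors: the first stage runs the
+// split-K GridwiseGemm for every group into a float workspace (the Argument's E pointers
+// are redirected there, see UpdateEPointers), and the second stage launches one grid-wise
+// elementwise kernel per group that combines the workspace with the D tensors through the
+// CDE elementwise operation and writes the final E tensor.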
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include +#include "ck/utility/tuple.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +template = false> +struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage + : public DeviceGroupedGemmMultipleDSplitK +{ + using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + // TODO change GridwiseGEMM v2r4r2 to support separate AK1 & BK1 + static constexpr index_t K0PerBlock = KPerBlock / AK1; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using WorkspaceDataType = float; + + // First stage GridwiseGEMM kernel. + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, + BDataType, + AccDataType, + WorkspaceDataType, + ALayout, + BLayout, + ELayout, + AElementwiseOperation, + BElementwiseOperation, + PassThrough, // CElementwiseOperation + GemmSpec, + NumGemmKPrefetchStage, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + AK1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer, + ComputeDataType>; + + template + static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + 
make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeElementwiseInputSequence() + { + return generate_sequence_v2( + [&]([[maybe_unused]] auto i) constexpr { + return Number{}; + }, + Number{}); + } + + using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using EGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); + using DsGridPointer = decltype(MakeDsGridPointer()); + using CDGridDesc_M_N = decltype(concat_tuple(ck::Tuple{}, DsGridDesc_M_N{})); + using CDDataTypes = decltype(concat_tuple(ck::Tuple{}, DsGridPointer{})); + + using ElementwiseInputSequence = decltype(MakeElementwiseInputSequence()); + + static constexpr index_t ClusterLengthMPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + using Block2ETileMapKSplit = + BlockToCTileMap_KSplit_M00_N0_M01Adapt; + using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt; + using GridwiseElementwise = + GridwiseElementwise, + CDDataTypes, + ck::Tuple, + Block2TileMap, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + ElementwiseInputSequence, + ck::Sequence, + true>; + + // Block2CTileMap configuration parameter. 
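+    // B2E_M01 below is forwarded as the M01 argument whenever a Block2ETileMapKSplit is
+    // constructed for a group (see the Argument constructor and UpdateKBatch).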
+ static constexpr index_t B2E_M01 = 8; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; + using GemmKernelArgument = typename GridwiseGemm::Argument; + + struct GemmTransKernelArg + { + GemmKernelArgument karg_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t block_start_, block_end_; + + GemmTransKernelArg() = default; + GemmTransKernelArg(GemmKernelArgument&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) + : karg_{karg}, + block_2_ctile_map_{b2c_map}, + block_start_{block_start}, + block_end_{block_end} + { + } + }; + + static constexpr index_t DefaultKBatch = 1; + + // Argument + struct Argument : public BaseArgument + { + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : Argument(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_element_op, + b_element_op, + cde_element_op, + DefaultKBatch) + { + } + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t kbatch) + : K_BATCH{kbatch}, + group_count_{0}, + skipped_group_count_{0}, + grid_size_{0}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + p_Ds_{p_Ds} + { + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("Error! 
group_count_ != p_As/Bs/Ds/Es size"); + } + + gemm_kernel_args_.reserve(group_count_); + elementwise_c_grid_descs_m_n_.reserve(group_count_); + elementwise_d_grid_descs_m_n_.reserve(group_count_); + ds_grid_pointer_.reserve(group_count_); + group_grid_size_.reserve(group_count_); + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + if(M * N * K == 0) + { + skipped_group_count_++; + continue; + } + + const index_t stride_a = gemm_descs[i].stride_A_; + const index_t stride_b = gemm_descs[i].stride_B_; + const index_t stride_e = gemm_descs[i].stride_C_; + + const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH); + const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH); + + const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_e); + + DsGridDesc_M_N ds_grid_desc_m_n; + DsGridPointer p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + p_ds_grid(j) = static_cast(p_Ds[i][j]); + ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N( + M, N, gemm_descs[i].stride_Ds_[j]); + }); + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + group_grid_size_[i] = grid_size_grp; + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + std::array stride_ds; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "Error! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + stride_ds[j] = gemm_descs[i].stride_Ds_[j]; + }); + stride_Ds_.emplace_back(std::move(stride_ds)); + + // We first set E pointer to actual operation output, but later on + // when workspace will be set, this will be updated to workspace memory. + auto karg = GemmKernelArgument{type_convert(p_As[i]), + type_convert(p_Bs[i]), + type_convert(p_Es[i]), + M, + N, + K, + stride_a, + stride_b, + stride_e, + m_padded, + n_padded, + k_padded, + k0_padded, + K_BATCH}; + + gemm_kernel_args_.emplace_back( + std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + + elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n); + elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n); + ds_grid_pointer_.push_back(p_ds_grid); + } + // Store a copy of E pointers for elementwise kernel destination + e_ptrs_ = p_Es; + } + + /** + * @brief Set new kbatch value. + * + * @param[in] kbatch The new splitK parameter value. 
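+         *
+         * @note In normal use this is reached through the device op wrapper rather than
+         *       being called directly; a minimal sketch with the names used by the
+         *       example in this patch:
+         * @code
+         *   auto gemm     = DeviceGemmInstance{};
+         *   auto argument = gemm.MakeArgument(p_As, p_Bs, p_Ds, p_Cs, gemm_descs,
+         *                                     a_element_op, b_element_op, cde_element_op);
+         *   gemm.SetKBatchSize(argument, config.k_batch); // forwards to UpdateKBatch()
+         * @endcode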
+ */ + void UpdateKBatch(index_t kbatch) + { + K_BATCH = kbatch; + grid_size_ = 0; + + for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i) + { + auto& karg = gemm_kernel_args_[i].karg_; + + const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); + const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH); + + const auto c_grid_desc_m_n = + GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + group_grid_size_[i] = grid_size_grp; + karg.KPadded = k_padded; + karg.K0Padded = k0_padded; + karg.k_batch = K_BATCH; + gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; + gemm_kernel_args_[i].block_start_ = block_start; + gemm_kernel_args_[i].block_end_ = block_end; + +#if DEBUG_LOG + index_t tiles = (block_end - block_start) / K_BATCH; + std::cout << "block_start: " << block_start << "\n" + << "block_end: " << block_end << "\n" + << "tiles: " << tiles << std::endl + << std::endl; + + std::cout << "KPadded: " << karg.KPadded << std::endl + << "K0Padded: " << karg.K0Padded << std::endl + << "KBatch: " << karg.k_batch << std::endl + << "grid_size_: " << karg.KPadded << std::endl; +#endif + } + } + + void UpdateEPointers() + { + // set-up each group E pointer to it's designated workspace memory. + WorkspaceDataType* p_workspace = reinterpret_cast(p_workspace_); + std::size_t offset = 0; + + for(auto& arg : gemm_kernel_args_) + { + arg.karg_.p_c_grid = p_workspace + offset; + index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch; + offset += tiles * MPerBlock * NPerBlock; +#if DEBUG_LOG + std::cout << "block_start: " << arg.block_start_ << "\n" + << "block_end: " << arg.block_end_ << "\n" + << "tiles: " << tiles << "\n" + << "offset: " << offset << std::endl; +#endif + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + std::size_t size_bytes{0}; + + for(const auto& arg : gemm_kernel_args_) + { + index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch; + size_bytes += tiles * MPerBlock * NPerBlock * sizeof(WorkspaceDataType); + } + return size_bytes; + } + + std::size_t GetWorkspaceSize(std::size_t group) const + { + const auto& arg = gemm_kernel_args_[group]; + index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch; + return tiles * MPerBlock * NPerBlock; + } + + // private: + index_t K_BATCH; + index_t group_count_; + index_t skipped_group_count_; + index_t grid_size_; + // Pointer to device memory with GEMM kernel arguments. + const void* p_dev_gemm_args_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + std::vector>& p_Ds_; + std::vector> stride_Ds_; + std::vector gemm_kernel_args_; + std::vector group_grid_size_; + + std::vector elementwise_c_grid_descs_m_n_; + std::vector elementwise_d_grid_descs_m_n_; + std::vector ds_grid_pointer_; + std::vector e_ptrs_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + /// + /// @brief Launch Grouped Gemm kernel. 
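+        ///
+        /// @note Internally this launches the split-K GEMM kernel, which accumulates each
+        ///       group's partial results into the float workspace, followed by one
+        ///       elementwise kernel per group that applies the CDE operation to the
+        ///       workspace and the D tensors and writes E (see DispatchKernel and
+        ///       LaunchKernel below).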
+ /// + /// @note This function overload is using user provided device buffer for kernel + /// arguments. + /// + /// @param[in] arg The structure containing kernel arguments (in host + /// memory). + /// @param[in] dev_gemm_args The pointer to device memory with kernel arguments. + /// @param[in] dev_gemm_workspace The pointer to device memory for kernel auxiliary + /// workspace. + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, + const void* dev_gemm_args, + void* dev_gemm_workspace, + const StreamConfig& stream_config = StreamConfig{}) + { + auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] = + CheckArgument(arg, stream_config); + + if(dev_gemm_args == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(dev_gemm_workspace == nullptr) + { + std::ostringstream err; + err << "The gemm workspace buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + float ave_time = 0; + + if(all_have_main_k_block_loop) + { + ave_time = + DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config); + } + else + { + ave_time = + DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config); + } + + return ave_time; + } + + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using device buffers (for kernel arguments and + /// for kernel auxiliary workspace) provided with an argument. The user should + /// call @see GetDeviceKernelArgSize, @see GetWorkSpaceSize and @see + /// SetDeviceKernelArgs, @see SetWorkSpacePointer on arg parameter to properly + /// allocate those buffers. + /// + /// @param[in] arg The structure containing kernel arguments (in host memory). + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.p_dev_gemm_args_ == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(arg.p_workspace_ == nullptr) + { + std::ostringstream err; + err << "The gemm workspace buffer is not allocated!" 
+ << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return Run(arg, arg.p_dev_gemm_args_, arg.p_workspace_, stream_config); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + + private: + auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const + { + bool all_have_kbatch_gt_one, all_have_main_k_block_loop; + + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1( + arg.gemm_kernel_args_[0].karg_.M, + arg.gemm_kernel_args_[0].karg_.MPadded, + arg.gemm_kernel_args_[0].karg_.K, + arg.gemm_kernel_args_[0].karg_.StrideA, + arg.gemm_kernel_args_[0].karg_.k_batch, + arg.gemm_kernel_args_[0].karg_.K0Padded, + arg.gemm_kernel_args_[0].karg_.KPadded); + + all_have_kbatch_gt_one = arg.K_BATCH > 1; + all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop( + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); + } + + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + if(stream_config.log_level_ > 0) + { + gemm_arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(gemm_arg)) + { + std::ostringstream err; + err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + const auto a_grid_desc_kbatch_ak0_m_ak1 = + GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1(gemm_arg.M, + gemm_arg.MPadded, + gemm_arg.K, + gemm_arg.StrideA, + gemm_arg.k_batch, + gemm_arg.K0Padded, + gemm_arg.KPadded); + + bool not_all_have_main_k_block_loop_same = + all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop( + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); + bool not_all_have_kbatch_value_same = + all_have_kbatch_gt_one xor (gemm_arg.k_batch > 1); + + if(not_all_have_main_k_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! 
" + << "group [" << i << "], kbatch: " << gemm_arg.k_batch + << ", group [0], kbatch: " << gemm_arg.k_batch << " in " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop); + } + + template + float DispatchKernel(const Argument& arg, + const void* dev_gemm_args, + void* dev_gemm_workspace, + const StreamConfig& stream_config) const + { + const auto gemm_kernel = + kernel_grouped_gemm_xdl_splitk; + + const auto elementwise_kernel = kernel_elementwise, + CDDataTypes, + ck::Tuple, + Block2TileMap, + CDEElementwiseOperation>; + return LaunchKernel(gemm_kernel, + elementwise_kernel, + arg, + dev_gemm_args, + dev_gemm_workspace, + stream_config); + } + + template + float LaunchKernel(const KernelFunction& gemm_kernel, + const KernelFunction2& elementwise_kernel, + const Argument& arg, + const void* dev_gemm_args, + [[maybe_unused]] void* dev_gemm_workspace, + const StreamConfig& stream_config) const + { + float time{0.f}; + + auto preprocess = [&]() { + hip_check_error(hipMemsetAsync( + dev_gemm_workspace, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); + }; + + // GEMM kernel + time = launch_and_time_kernel_with_preprocess( + stream_config, + preprocess, + gemm_kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(dev_gemm_args), + arg.group_count_, + arg.a_element_op_, + arg.b_element_op_, + PassThrough{}); + + // Elementwise kernels + for(int i = 0; i < arg.group_count_; ++i) + { + time += launch_and_time_kernel( + stream_config, + elementwise_kernel, + dim3(arg.group_grid_size_[i]), + dim3(BlockSize), + 0, + concat_tuple(make_tuple(arg.elementwise_c_grid_descs_m_n_[i]), + arg.elementwise_d_grid_descs_m_n_[i]), + make_tuple(arg.elementwise_c_grid_descs_m_n_[i]), + concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid), + arg.ds_grid_pointer_[i]), + type_convert(arg.e_ptrs_[i]), + Block2TileMap{arg.elementwise_c_grid_descs_m_n_[i].GetLength(I0), + arg.elementwise_c_grid_descs_m_n_[i].GetLength(I1)}, + arg.cde_element_op_); + } + return time; + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((ck::type_convert(arg.gemm_kernel_args_.size()) + + arg.skipped_group_count_) != arg.group_count_) + { +#if DEBUG_LOG + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; +#endif // DEBUG_LOG + return false; + } + + bool supported = true; + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + + bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg); + if(not group_arg_valid) + { +#if DEBUG_LOG + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" 
<< std::endl; + gemm_arg.Print(); +#endif // DEBUG_LOG + } + supported = supported && group_arg_valid; + } + return supported; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) + { + return Argument{p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op}; + } + + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) override + { + return std::make_unique(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << ">"; + // clang-format on + + return str.str(); + } + + void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + hip_check_error(hipMemcpy(p_dev_kernel_args, + arg.gemm_kernel_args_.data(), + GetDeviceKernelArgSize(&arg), + hipMemcpyHostToDevice)); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), p_dev_kernel_args); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); + } + + void SetWorkSpacePointer( + BaseArgument* p_arg, + void* p_workspace, + [[maybe_unused]] const StreamConfig& stream_config = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + p_arg_->UpdateEPointers(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); + } + + static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } + + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + return SetKBatchSize(*dynamic_cast(p_arg), kbatch); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) 
const override + { + return dynamic_cast(p_arg)->gemm_kernel_args_.size() * + sizeof(GemmTransKernelArg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index abee2fea5..a33d7d8fb 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -26,13 +26,19 @@ namespace device { template + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, + typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, + typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count) + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ defined(__gfx94__)) @@ -64,10 +70,16 @@ __global__ void GridwiseGemm::template Run( gemm_desc_ptr[group_id].karg_, static_cast(p_shared), - gemm_desc_ptr[group_id].block_2_ctile_map_); + gemm_desc_ptr[group_id].block_2_ctile_map_, + a_element_op, + b_element_op, + c_element_op); #else ignore = gemm_descs_const; ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -193,7 +205,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK; using KernelArgument = typename GridwiseGemm::Argument; - + using PassThrough = ck::tensor_operation::element_wise::PassThrough; struct GemmTransKernelArg { KernelArgument karg_; @@ -437,7 +449,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +// assumption: every D matrix has the same layout and the same datatype +template +struct ReferenceGemmMultipleD : public device::BaseOperator +{ + using DDataType = remove_cvref_t>; + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + const std::array, DsDataType::Size()>& ds_m_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + ds_m_n_{ds_m_n}, + c_m_n_{c_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + const std::array, DsDataType::Size()>& ds_m_n_; + Tensor& c_m_n_; + 
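+        // Element-wise operators of the reference computation: the A/B operators are
+        // applied to the loaded input values, the CDE operator combines the accumulator
+        // with the D values to produce the final C value.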
+ AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmMultipleD::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + AccDataType v_acc = 0; + ComputeTypeA v_a = 0; + ComputeTypeB v_b = 0; + + for(int k = 0; k < K; ++k) + { + // use PassThrough instead of ConvertBF16RTN for reference calculation + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k)); + } + else + { + arg.a_element_op_(v_a, arg.a_m_k_(m, k)); + } + // same for B matrix + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n)); + } + else + { + arg.b_element_op_(v_b, arg.b_k_n_(k, n)); + } + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + + CDataType v_c = 0; + + if constexpr(DsDataType::Size() == 0) + { + arg.cde_element_op_(v_c, v_acc); + } + else if constexpr(DsDataType::Size() == 1) + { + arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n)); + } + else if constexpr(DsDataType::Size() == 2) + { + arg.cde_element_op_(v_c, v_acc, arg.ds_m_n_[0](m, n), arg.ds_m_n_[1](m, n)); + } + + arg.c_m_n_(m, n) = v_c; + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + const std::array, DsDataType::Size()>& ds_m_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{a_m_k, b_k_n, ds_m_n, c_m_n, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmMultipleD" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp index 056e906c2..d06a57981 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -146,6 +146,32 @@ void add_device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances( + std::vector>>& instances); + template > op_ptrs; +#if defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && is_same_v) { @@ -190,6 +217,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) @@ -210,8 +239,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v) +#endif +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) + if constexpr(is_same_v && is_same_v && + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) @@ -228,6 +259,19 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances( + op_ptrs); + } + } +#endif return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index 2625e6cbe..5a50eca10 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -10,4 +10,6 @@ add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp device_grouped_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_irregular_instance.cpp device_grouped_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_irregular_instance.cpp + device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 000000000..8d3baf19e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = int8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Instances having AK1!=BK1 are temporarily disabled and will be re-enabled in future +// a[m, k] * b[k, n] = e[m, n] +using device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_generic_instances = + std::tuple< + // clang-format off + //#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| 
CBlockTransfer| + //#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 
S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + 
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, BF16, I8, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1> + // clang-format on + >; + +void add_device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_grouped_gemm_multiple_d_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_generic_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 000000000..d38484234 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_multiple_d_splitk_xdl_two_stage_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
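+// Reading guide for the instance tables below: each
+// DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage entry is one pre-built tile
+// configuration of the split-K, two-stage grouped GEMM kernel, and the //### header
+// rows name its template parameters. As a rough illustration, the first entry of the
+// non-generic list uses a 256-thread block computing a 256x128 output tile with a
+// 32-wide K step (AK1 = BK1 = 8), 32x32 XDL tiles repeated 4x2 per wave, and
+// PipelineVersion::v1. The "generic" tuple keeps the ScalarPerVector parameters at 1,
+// so it can in principle serve problem shapes the vectorized entries cannot; the exact
+// dispatch rules live in the device operation itself.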
+ +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; +using Empty_Tuple = ck::Tuple<>; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Instances having AK1!=BK1 are temporarily disabled and will be re-enabled in future +// a[m, k] * b[k, n] = e[m, n] +using device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_generic_instances = + std::tuple< + // clang-format off + //#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 1, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1> + // clang-format on + >; + +using device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#################################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + 
//#################################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#################################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 192, 64, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 48, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, 
S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 192, 32, 8, 8, 32, 32, 1, 3, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 24, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 192, 32, 32, 8, 8, 32, 32, 3, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 64, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, 
Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1>, + DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< Row, Row, Empty_Tuple, Row, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 4, PipelineVersion::v1> + // clang-format on + >; + +void add_device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_instances{}); + add_device_operation_instances( + instances, + device_grouped_gemm_multiple_d_xdl_two_stage_f16_f16_f16_mk_kn_mn_generic_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp new file mode 100644 index 000000000..41dcabbfc --- /dev/null +++ b/profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp @@ -0,0 +1,366 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
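+// High-level flow of profile_grouped_gemm_two_stage_impl, as implemented below: host
+// tensors are built and initialized per GEMM group, a CPU reference GEMM optionally
+// produces the verification baseline, and every split-K-capable instance returned by
+// the DeviceOperationInstanceFactory (non-split-K grouped GEMMs are skipped) is run
+// over a sweep of kbatch values {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64}, unless a
+// positive kbatch is given. For some data types with kbatch > 1 the comparison uses a
+// relaxed 0.06 tolerance, and the best-performing instance/kbatch pair is reported in
+// TFlops and GB/s.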
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_gemm_two_stage_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::size_t group_count = Ms.size(); + + if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && + group_count == StrideBs.size() && group_count == StrideCs.size())) + { + throw std::runtime_error("wrong! 
inconsistent M/N/Ks, StrideA/B/Cs size\n"); + } + + std::vector> a_m_k; + std::vector> b_k_n; + std::vector> c_m_n_host_results; + std::vector> c_m_n_device_results; + + for(std::size_t i = 0; i < group_count; i++) + { + a_m_k.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); + b_k_n.push_back( + Tensor(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{}))); + + c_m_n_device_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); + + c_m_n_host_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); +#if DEBUG_LOG + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i + << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; +#endif // DEBUG_LOG + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + a_m_k[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a_m_k[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + using DeviceMemPtr = std::unique_ptr; + std::vector a_device_buf, b_device_buf, c_device_buf; + + a_device_buf.reserve(group_count); + b_device_buf.reserve(group_count); + c_device_buf.reserve(group_count); + + std::vector p_a, p_b; + std::vector p_c; + + p_a.reserve(group_count); + p_b.reserve(group_count); + p_c.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + for(std::size_t i = 0; i < group_count; i++) + { + a_device_buf.emplace_back( + std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize())); + b_device_buf.emplace_back( + std::make_unique(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize())); + c_device_buf.emplace_back(std::make_unique( + sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize())); + + a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); + b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); + + gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + + p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); + p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); + p_c.push_back(c_device_buf[i]->GetDeviceBuffer()); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemm, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AElementOp, + BElementOp, + CElementOp>; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! 
no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_kbatch = 0; + + auto p_ds = std::vector>{}; + + if(do_verification) + { + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k[i], + b_k_n[i], + c_m_n_host_results[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + } + } + + // profile device GEMM instances + for(auto& gemm_ptr : op_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(p_a, + p_b, + p_ds, + p_c, + gemm_descs, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + using DeviceOpSplitK = + ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitK, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AElementOp, + BElementOp, + CElementOp>; + + // skip non-splitk grouped_gemm + if(dynamic_cast(gemm_ptr.get()) == nullptr) + { + continue; + } + + std::vector kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64}; + + if(kbatch > 0) + { + kbatch_list = {kbatch}; + } + + for(std::size_t j = 0; j < kbatch_list.size(); j++) + { + + auto kbatch_curr = kbatch_list[j]; + dynamic_cast(gemm_ptr.get()) + ->SetKBatchSize(argument_ptr.get(), kbatch_curr); + + DeviceMem gemm_arg_dev_mem(dynamic_cast(gemm_ptr.get()) + ->GetDeviceKernelArgSize(argument_ptr.get())); + dynamic_cast(gemm_ptr.get()) + ->SetDeviceKernelArgs(argument_ptr.get(), gemm_arg_dev_mem.GetDeviceBuffer()); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + gemm_desc_workspace.SetZero(); + for(std::size_t i = 0; i < gemm_descs.size(); i++) + c_device_buf[i]->SetZero(); + + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + if(do_verification) + { + bool instance_pass = true; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); + if(std::is_same_v && kbatch_curr > 1) + { + instance_pass = + instance_pass && ck::utils::check_err(c_m_n_device_results[i], + c_m_n_host_results[i], + "Error: Incorrect results!", + 0.06); + } + else + { + instance_pass = + instance_pass && ck::utils::check_err(c_m_n_device_results[i], + c_m_n_host_results[i]); + } + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_results[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_results[i].mData, ",") + << std::endl; + } + } + + std::cout << "Instance: " << gemm_name << " verification " + << (instance_pass ? 
"SUCCEED" : "FAILED") << std::endl; + + pass = pass && instance_pass; + } + float ave_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); + if(time_kernel) + { + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(CDataType) * Ms[i] * Ns[i]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " + << kbatch_curr << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + } + } + else + { + std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + << std::endl; + } + } + } + + if(time_kernel) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch + << std::endl; + } + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index cb6ffbec6..e8992070b 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -40,6 +40,7 @@ if(GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_gemm_add_silu.cpp) list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_two_stage.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) endif() list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) diff --git a/profiler/src/profile_grouped_gemm_two_stage.cpp b/profiler/src/profile_grouped_gemm_two_stage.cpp new file mode 100644 index 000000000..17daf1e80 --- /dev/null +++ b/profiler/src/profile_grouped_gemm_two_stage.cpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include
+#include
+#include
+#include
+
+#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
+#include "profiler_operation_registry.hpp"
+
+enum struct GemmMatrixLayout
+{
+    MK_KN_MN, // 0
+};
+
+enum struct GemmDataType
+{
+    F16_F16_F16,   // 0
+    BF16_INT8_BF16 // 1
+};
+
+#define OP_NAME "grouped_gemm_two_stage"
+#define OP_DESC "Grouped GEMM TwoStage"
+
+namespace {
+
+std::vector argToIntArray(char* input)
+{
+    std::vector out;
+
+    std::istringstream in(input);
+
+    std::string item;
+
+    while(std::getline(in, item, ','))
+    {
+        out.push_back(std::stoi(item));
+    }
+
+    return out;
+}
+
+int profile_grouped_gemm_two_stage(int argc, char* argv[])
+{
+    if(argc < 14)
+    {
+        std::cout
+            << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
+            << "arg2: data type (0: fp16; 1: bf16@int8)\n"
+            << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n]);\n"
+            << "arg4: verification (0: no; 1: yes)\n"
+            << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
+            << "arg6: print tensor value (0: no; 1: yes)\n"
+            << "arg7: time kernel (0: no; 1: yes)\n"
+            << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
+               "64,64 64,64 128,128)\n"
+            << "arg14: kbatch value (default 1)\n"
+            << "optional:\n"
+            << "arg15: number of warm-up cycles (default 1)\n"
+            << "arg16: number of iterations (default 10)\n"
+            << std::endl;
+
+        exit(1);
+    }
+
+    const auto data_type       = static_cast<GemmDataType>(std::stoi(argv[2]));
+    const auto layout          = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const bool do_verification = std::stoi(argv[4]);
+    const int init_method      = std::stoi(argv[5]);
+    const bool do_log          = std::stoi(argv[6]);
+    const bool time_kernel     = std::stoi(argv[7]);
+
+    const auto Ms = argToIntArray(argv[8]);
+    const auto Ns = argToIntArray(argv[9]);
+    const auto Ks = argToIntArray(argv[10]);
+
+    auto StrideAs = argToIntArray(argv[11]);
+    auto StrideBs = argToIntArray(argv[12]);
+    auto StrideCs = argToIntArray(argv[13]);
+
+    const int kbatch = argc >= 15 ? std::stoi(argv[14]) : 1;
+
+    const int DefaultStrideA = Ks[0];
+    const int DefaultStrideB = Ns[0];
+    const int DefaultStrideC = Ns[0];
+
+    for(size_t i = 0; i < Ms.size(); ++i)
+    {
+        StrideAs[i] = StrideAs[i] == -1 ? DefaultStrideA : StrideAs[i];
+        StrideBs[i] = StrideBs[i] == -1 ? DefaultStrideB : StrideBs[i];
+        StrideCs[i] = StrideCs[i] == -1 ? DefaultStrideC : StrideCs[i];
+    }
+
+    int n_warmup = 1;
+    int n_iter   = 10;
+    if(argc >= 17)
+    {
+        n_warmup = std::stoi(argv[15]);
+        n_iter   = std::stoi(argv[16]);
+    }
+
+    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        ck::profiler::profile_grouped_gemm_two_stage_impl(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            Ms,
+            Ns,
+            Ks,
+            StrideAs,
+            StrideBs,
+            StrideCs,
+            kbatch,
+            n_warmup,
+            n_iter);
+    }
+    else if(data_type == GemmDataType::BF16_INT8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
+    {
+        ck::profiler::profile_grouped_gemm_two_stage_impl(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            Ms,
+            Ns,
+            Ks,
+            StrideAs,
+            StrideBs,
+            StrideCs,
+            kbatch,
+            n_warmup,
+            n_iter);
+    }
+    else
+    {
+        throw std::runtime_error("wrong! 
this GEMM data_type & layout is not implemented"); + } + return 0; +} + +} // anonymous namespace + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_two_stage); -- GitLab From 7e5c81fed2737312f960cd41fe9afbc02669ce27 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 4 Apr 2024 11:33:29 -0700 Subject: [PATCH 19/63] fix the latest errors with staging compiler (#1229) --- test/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 720ab468e..bbb75c49e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -140,6 +140,7 @@ function(add_gtest_executable TEST_NAME) set(result ${result} PARENT_SCOPE) endfunction() +add_compile_options(-Wno-c++20-extensions) add_subdirectory(magic_number_division) add_subdirectory(space_filling_curve) add_subdirectory(conv_util) -- GitLab From 50cc0a13a69aa6bd59a35bc84759688c3552cc45 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Tue, 9 Apr 2024 13:56:54 -0500 Subject: [PATCH 20/63] Add an example (#1225) --- example/09_convnd_fwd/CMakeLists.txt | 1 + .../convnd_fwd_xdl_fp16_comp_fp8.cpp | 81 +++++++++++++++++++ 2 files changed, 82 insertions(+) create mode 100644 example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index afbe74121..778e81872 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -5,6 +5,7 @@ add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp16_comp_fp8 convnd_fwd_xdl_fp16_comp_fp8.cpp) add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp) add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp new file mode 100644 index 000000000..346ab8d95 --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16_comp_fp8.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
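+// What this example exercises (as the file name fp16_comp_fp8 suggests): the input,
+// weight and output tensors are ck::half_t, but ComputeType is ck::f8_t, so the
+// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle instance below is meant to carry out
+// the inner GEMM arithmetic in fp8 while accumulating in float (AccDataType) and
+// writing fp16 results; otherwise it follows the same forward-convolution setup as
+// the other convnd_fwd_xdl examples in this directory.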
+ +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; +using ComputeType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::UnaryConvert; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + ComputeType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 
0 : 1; } -- GitLab From 366592b0ff47f1fe986d08c4b78fa8c87f6a3751 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Tue, 9 Apr 2024 13:57:32 -0500 Subject: [PATCH 21/63] Add an example (#1227) --- .../CMakeLists.txt | 3 ++ .../common.hpp | 2 + ...ed_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp | 38 +++++++++++++++++++ 3 files changed, 43 insertions(+) create mode 100644 example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp diff --git a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt index 72e695964..ce951f635 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt +++ b/example/38_grouped_conv_bwd_data_multiple_d/CMakeLists.txt @@ -3,6 +3,9 @@ add_custom_target(example_grouped_conv_bwd_data) add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp) add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16) +add_example_executable(example_grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8 grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp) +add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8) + add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp) add_example_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16) diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp index ebb1c606c..8a0474156 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -34,6 +34,8 @@ static constexpr auto ConvBwdDataDefault = using FP16 = ck::half_t; using FP32 = float; +using FP8 = ck::f8_t; +using BF8 = ck::bf8_t; struct ExecutionConfig final { diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp new file mode 100644 index 000000000..41023ef82 --- /dev/null +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
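+// As the file name fp16_comp_bf8_fp8 suggests, this backward-data example keeps all
+// tensors in fp16 but passes BF8 and FP8 as the AComputeType/BComputeType template
+// arguments of DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, i.e. the two GEMM
+// operands are meant to be computed in 8-bit float formats while the accumulator
+// stays in fp32 (AccDataType); everything else mirrors the plain
+// grouped_conv_bwd_data_xdl_fp16 example.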
+ +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "common.hpp" + +using OutDataType = FP16; +using WeiDataType = FP16; +using AccDataType = FP32; +using CShuffleDataType = FP16; +using DsDataType = ck::Tuple<>; +using InDataType = FP16; +using AComputeType = BF8; +using BComputeType = FP8; + +using OutLayout = ck::tensor_layout::convolution::GNHWK; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using DsLayout = ck::Tuple<>; +using InLayout = ck::tensor_layout::convolution::GNHWC; + +using OutElementOp = PassThrough; +using WeiElementOp = PassThrough; +using InElementOp = PassThrough; + +static constexpr auto LoopSched = ck::make_default_loop_scheduler(); + +// clang-format off +using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 +// ######| NDimSpatial| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| Loop| ACompute| BCompute| +// ######| | | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| Scheduler| Type| Type| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| | | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < NDimSpatial, OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, AComputeType, BComputeType>; +// clang-format on + +#include "run_grouped_conv_bwd_data_example.inc" + +int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } -- GitLab From ced5af16f7ac072315d9fb270ac86b77cff6f8c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 9 Apr 2024 23:46:21 +0200 Subject: [PATCH 22/63] Extend support for contraction 6D (#1207) * Extend support for contraction up to 5D * Extend contraction bilinear instances * Fix interface test * Add 6d support, remove 3d,4d,5d * Fixes * Fix readme * Make defualt dim for contraction instances --- .../cpu/reference_contraction.hpp | 138 +++++- .../device_contraction_instance.hpp | 252 +++++----- .../gpu/contraction_bilinear.hpp | 464 +++++++++++++++++- .../gpu/contraction_scale.hpp | 464 +++++++++++++++++- .../ck/library/utility/host_tensor.hpp | 33 +- 
...16_bf16_bf16_compute_f32_kknn_instance.cpp | 5 +- ...16_bf16_bf16_compute_f32_knnn_instance.cpp | 5 +- ...16_bf16_bf16_compute_f32_mknn_instance.cpp | 5 +- ...16_bf16_bf16_compute_f32_mnnn_instance.cpp | 5 +- ..._f16_f16_f16_compute_f32_kknn_instance.cpp | 5 +- ..._f16_f16_f16_compute_f32_knnn_instance.cpp | 5 +- ..._f16_f16_f16_compute_f32_mknn_instance.cpp | 5 +- ..._f16_f16_f16_compute_f32_mnnn_instance.cpp | 5 +- ...f32_f32_f32_compute_bf16_kknn_instance.cpp | 5 +- ...f32_f32_f32_compute_bf16_knnn_instance.cpp | 5 +- ...f32_f32_f32_compute_bf16_mknn_instance.cpp | 5 +- ...f32_f32_f32_compute_bf16_mnnn_instance.cpp | 5 +- ..._f32_f32_f32_compute_f16_kknn_instance.cpp | 5 +- ..._f32_f32_f32_compute_f16_knnn_instance.cpp | 5 +- ..._f32_f32_f32_compute_f16_mknn_instance.cpp | 5 +- ..._f32_f32_f32_compute_f16_mnnn_instance.cpp | 5 +- ..._shuffle_f32_f32_f32_f32_kknn_instance.cpp | 5 +- ..._shuffle_f32_f32_f32_f32_knnn_instance.cpp | 5 +- ..._shuffle_f32_f32_f32_f32_mknn_instance.cpp | 5 +- ..._shuffle_f32_f32_f32_f32_mnnn_instance.cpp | 5 +- ..._f64_f64_f64_compute_f32_kknn_instance.cpp | 5 +- ..._f64_f64_f64_compute_f32_knnn_instance.cpp | 5 +- ..._f64_f64_f64_compute_f32_mknn_instance.cpp | 5 +- ..._f64_f64_f64_compute_f32_mnnn_instance.cpp | 5 +- ..._shuffle_f64_f64_f64_f64_kknn_instance.cpp | 5 +- ..._shuffle_f64_f64_f64_f64_knnn_instance.cpp | 5 +- ..._shuffle_f64_f64_f64_f64_mknn_instance.cpp | 5 +- ..._shuffle_f64_f64_f64_f64_mnnn_instance.cpp | 5 +- ...16_bf16_bf16_compute_f32_kknn_instance.cpp | 58 +++ ...16_bf16_bf16_compute_f32_knnn_instance.cpp | 58 +++ ...16_bf16_bf16_compute_f32_mknn_instance.cpp | 58 +++ ...16_bf16_bf16_compute_f32_mnnn_instance.cpp | 58 +++ ..._f16_f16_f16_compute_f32_kknn_instance.cpp | 58 +++ ..._f16_f16_f16_compute_f32_knnn_instance.cpp | 58 +++ ..._f16_f16_f16_compute_f32_mknn_instance.cpp | 58 +++ ..._f16_f16_f16_compute_f32_mnnn_instance.cpp | 58 +++ ...f32_f32_f32_compute_bf16_kknn_instance.cpp | 58 +++ ...f32_f32_f32_compute_bf16_knnn_instance.cpp | 58 +++ ...f32_f32_f32_compute_bf16_mknn_instance.cpp | 58 +++ ...f32_f32_f32_compute_bf16_mnnn_instance.cpp | 58 +++ ..._f32_f32_f32_compute_f16_kknn_instance.cpp | 58 +++ ..._f32_f32_f32_compute_f16_knnn_instance.cpp | 58 +++ ..._f32_f32_f32_compute_f16_mknn_instance.cpp | 58 +++ ..._f32_f32_f32_compute_f16_mnnn_instance.cpp | 58 +++ ..._shuffle_f32_f32_f32_f32_kknn_instance.cpp | 58 +++ ..._shuffle_f32_f32_f32_f32_knnn_instance.cpp | 58 +++ ..._shuffle_f32_f32_f32_f32_mknn_instance.cpp | 58 +++ ..._shuffle_f32_f32_f32_f32_mnnn_instance.cpp | 58 +++ ..._f64_f64_f64_compute_f32_kknn_instance.cpp | 58 +++ ..._f64_f64_f64_compute_f32_knnn_instance.cpp | 58 +++ ..._f64_f64_f64_compute_f32_mknn_instance.cpp | 58 +++ ..._f64_f64_f64_compute_f32_mnnn_instance.cpp | 58 +++ ..._shuffle_f64_f64_f64_f64_kknn_instance.cpp | 58 +++ ..._shuffle_f64_f64_f64_f64_knnn_instance.cpp | 58 +++ ..._shuffle_f64_f64_f64_f64_mknn_instance.cpp | 58 +++ ..._shuffle_f64_f64_f64_f64_mnnn_instance.cpp | 58 +++ .../gpu/contraction_bilinear/CMakeLists.txt | 82 ++-- ...f16_bf16_bf16_compute_f32_kkn_instance.cpp | 5 +- ...f16_bf16_bf16_compute_f32_knn_instance.cpp | 5 +- ...f16_bf16_bf16_compute_f32_mkn_instance.cpp | 5 +- ...f16_bf16_bf16_compute_f32_mnn_instance.cpp | 5 +- ...e_f16_f16_f16_compute_f32_kkn_instance.cpp | 5 +- ...e_f16_f16_f16_compute_f32_knn_instance.cpp | 5 +- ...e_f16_f16_f16_compute_f32_mkn_instance.cpp | 5 +- ...e_f16_f16_f16_compute_f32_mnn_instance.cpp | 5 +- ..._f32_f32_f32_compute_bf16_kkn_instance.cpp | 5 +- 
..._f32_f32_f32_compute_bf16_knn_instance.cpp | 5 +- ..._f32_f32_f32_compute_bf16_mkn_instance.cpp | 5 +- ..._f32_f32_f32_compute_bf16_mnn_instance.cpp | 5 +- ...e_f32_f32_f32_compute_f16_kkn_instance.cpp | 5 +- ...e_f32_f32_f32_compute_f16_knn_instance.cpp | 5 +- ...e_f32_f32_f32_compute_f16_mkn_instance.cpp | 5 +- ...e_f32_f32_f32_compute_f16_mnn_instance.cpp | 5 +- ...xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp | 5 +- ...xdl_c_shuffle_f32_f32_f32_knn_instance.cpp | 5 +- ...xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp | 5 +- ...xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp | 5 +- ...e_f64_f64_f64_compute_f32_kkn_instance.cpp | 5 +- ...e_f64_f64_f64_compute_f32_knn_instance.cpp | 5 +- ...e_f64_f64_f64_compute_f32_mkn_instance.cpp | 5 +- ...e_f64_f64_f64_compute_f32_mnn_instance.cpp | 5 +- ...xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp | 5 +- ...xdl_c_shuffle_f64_f64_f64_knn_instance.cpp | 5 +- ...xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp | 5 +- ...xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp | 5 +- ...f16_bf16_bf16_compute_f32_kkn_instance.cpp | 58 +++ ...f16_bf16_bf16_compute_f32_knn_instance.cpp | 58 +++ ...f16_bf16_bf16_compute_f32_mkn_instance.cpp | 58 +++ ...f16_bf16_bf16_compute_f32_mnn_instance.cpp | 58 +++ ...e_f16_f16_f16_compute_f32_kkn_instance.cpp | 58 +++ ...e_f16_f16_f16_compute_f32_knn_instance.cpp | 58 +++ ...e_f16_f16_f16_compute_f32_mkn_instance.cpp | 58 +++ ...e_f16_f16_f16_compute_f32_mnn_instance.cpp | 58 +++ ..._f32_f32_f32_compute_bf16_kkn_instance.cpp | 58 +++ ..._f32_f32_f32_compute_bf16_knn_instance.cpp | 58 +++ ..._f32_f32_f32_compute_bf16_mkn_instance.cpp | 58 +++ ..._f32_f32_f32_compute_bf16_mnn_instance.cpp | 58 +++ ...e_f32_f32_f32_compute_f16_kkn_instance.cpp | 58 +++ ...e_f32_f32_f32_compute_f16_knn_instance.cpp | 58 +++ ...e_f32_f32_f32_compute_f16_mkn_instance.cpp | 58 +++ ...e_f32_f32_f32_compute_f16_mnn_instance.cpp | 58 +++ ...xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp | 57 +++ ...xdl_c_shuffle_f32_f32_f32_knn_instance.cpp | 57 +++ ...xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp | 57 +++ ...xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp | 57 +++ ...e_f64_f64_f64_compute_f32_kkn_instance.cpp | 58 +++ ...e_f64_f64_f64_compute_f32_knn_instance.cpp | 58 +++ ...e_f64_f64_f64_compute_f32_mkn_instance.cpp | 58 +++ ...e_f64_f64_f64_compute_f32_mnn_instance.cpp | 58 +++ ...xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp | 57 +++ ...xdl_c_shuffle_f64_f64_f64_knn_instance.cpp | 57 +++ ...xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp | 57 +++ ...xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp | 57 +++ .../gpu/contraction_scale/CMakeLists.txt | 82 ++-- profiler/README.md | 24 +- .../profiler/profile_contraction_impl.hpp | 80 ++- .../profiler/profile_contraction_utils.hpp | 28 +- profiler/src/profile_contraction_bilinear.cpp | 154 +++--- profiler/src/profile_contraction_scale.cpp | 147 ++++-- .../test_contraction_interface_xdl.cpp | 14 +- test/contraction/test_contraction_xdl.cpp | 138 ++++-- 126 files changed, 5051 insertions(+), 569 deletions(-) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp 
(94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 
2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_bilinear/{ => 2D}/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp (94%) create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp create mode 
100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 
2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp (94%) rename library/src/tensor_operation_instance/gpu/contraction_scale/{ => 
2D}/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp (94%) create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp index 527dac6d3..38557e349 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -15,7 +15,6 @@ namespace ck { namespace tensor_operation { namespace host { -// hardcoded for NumDimM == NumDimN == NumDimK == 2 template = false> + ck::enable_if_t<(NumDimM == 2 || NumDimM == 6) && (NumDimN == 2 || NumDimN == 6) && + (NumDimK == 2 || NumDimK == 6), + bool> = false> struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::BaseOperator { // Argument @@ -60,9 +61,28 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base float Run(const Argument& arg) { - auto f_ms_ns = [&](auto m0, auto m1, auto n0, auto n1) { - const ck::index_t K0 = arg.a_ms_ks_.mDesc.GetLengths()[2]; - const ck::index_t K1 = arg.a_ms_ks_.mDesc.GetLengths()[3]; + auto f_ms_ns = [&](auto m0, + auto m1, + auto m2, + auto m3, + auto m4, + auto m5, + auto n0, + auto n1, + auto n2, + auto n3, + auto n4, + auto n5) { + const ck::index_t K0 = arg.a_ms_ks_.mDesc.GetLengths()[NumDimM]; + const ck::index_t K1 = arg.a_ms_ks_.mDesc.GetLengths()[NumDimM + 1]; + const ck::index_t K2 = + NumDimK >= 3 ? arg.a_ms_ks_.mDesc.GetLengths()[NumDimM + 2] : 1; + const ck::index_t K3 = + NumDimK >= 4 ? arg.a_ms_ks_.mDesc.GetLengths()[NumDimM + 3] : 1; + const ck::index_t K4 = + NumDimK >= 5 ? arg.a_ms_ks_.mDesc.GetLengths()[NumDimM + 4] : 1; + const ck::index_t K5 = + NumDimK >= 6 ? 
arg.a_ms_ks_.mDesc.GetLengths()[NumDimM + 5] : 1; AccDataType v_acc = 0; @@ -70,32 +90,96 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base { for(ck::index_t k1 = 0; k1 < K1; ++k1) { - // Simulate the possible casting when ComputeDataType is different than the - // A/B data types - ComputeDataType v_a_compute_input = - ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1)); - ComputeDataType v_b_compute_input = - ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1)); - - AccDataType v_a; - AccDataType v_b; - - arg.a_element_op_(v_a, ck::type_convert(v_a_compute_input)); - arg.b_element_op_(v_b, ck::type_convert(v_b_compute_input)); - - v_acc += v_a * v_b; + for(ck::index_t k2 = 0; k2 < K2; ++k2) + { + for(ck::index_t k3 = 0; k3 < K3; ++k3) + { + for(ck::index_t k4 = 0; k4 < K4; ++k4) + { + for(ck::index_t k5 = 0; k5 < K5; ++k5) + { + ComputeDataType v_a_compute_input; + ComputeDataType v_b_compute_input; + + // Simulate the possible casting when ComputeDataType is + // different than the A/B data types + if constexpr(NumDimK == 2) + { + v_a_compute_input = ck::type_convert( + arg.a_ms_ks_(m0, m1, k0, k1)); + v_b_compute_input = ck::type_convert( + arg.b_ns_ks_(n0, n1, k0, k1)); + } + else if constexpr(NumDimK == 6) + { + v_a_compute_input = ck::type_convert< + ComputeDataType>(arg.a_ms_ks_( + m0, m1, m2, m3, m4, m5, k0, k1, k2, k3, k4, k5)); + v_b_compute_input = ck::type_convert< + ComputeDataType>(arg.b_ns_ks_( + n0, n1, n2, n3, n4, n5, k0, k1, k2, k3, k4, k5)); + } + + AccDataType v_a; + AccDataType v_b; + + arg.a_element_op_( + v_a, ck::type_convert(v_a_compute_input)); + arg.b_element_op_( + v_b, ck::type_convert(v_b_compute_input)); + + v_acc += v_a * v_b; + } + } + } + } } } - arg.c_ms_ns_(m0, m1, n0, n1) = ck::type_convert(v_acc); + if constexpr(NumDimK == 2) + { + arg.c_ms_ns_(m0, m1, n0, n1) = ck::type_convert(v_acc); + } + else if constexpr(NumDimK == 6) + { + arg.c_ms_ns_(m0, m1, m2, m3, m4, m5, n0, n1, n2, n3, n4, n5) = + ck::type_convert(v_acc); + } }; - make_ParallelTensorFunctor(f_ms_ns, - arg.c_ms_ns_.mDesc.GetLengths()[0], - arg.c_ms_ns_.mDesc.GetLengths()[1], - arg.c_ms_ns_.mDesc.GetLengths()[2], - arg.c_ms_ns_.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); + if constexpr(NumDimK == 2) + { + make_ParallelTensorFunctor(f_ms_ns, + arg.c_ms_ns_.mDesc.GetLengths()[0], + arg.c_ms_ns_.mDesc.GetLengths()[1], + 1, + 1, + 1, + 1, + arg.c_ms_ns_.mDesc.GetLengths()[2], + arg.c_ms_ns_.mDesc.GetLengths()[3], + 1, + 1, + 1, + 1)(std::thread::hardware_concurrency()); + } + else if constexpr(NumDimK == 6) + { + make_ParallelTensorFunctor(f_ms_ns, + arg.c_ms_ns_.mDesc.GetLengths()[0], + arg.c_ms_ns_.mDesc.GetLengths()[1], + arg.c_ms_ns_.mDesc.GetLengths()[2], + arg.c_ms_ns_.mDesc.GetLengths()[3], + arg.c_ms_ns_.mDesc.GetLengths()[4], + arg.c_ms_ns_.mDesc.GetLengths()[5], + arg.c_ms_ns_.mDesc.GetLengths()[6], + arg.c_ms_ns_.mDesc.GetLengths()[7], + arg.c_ms_ns_.mDesc.GetLengths()[8], + arg.c_ms_ns_.mDesc.GetLengths()[9], + arg.c_ms_ns_.mDesc.GetLengths()[10], + arg.c_ms_ns_.mDesc.GetLengths()[11])( + std::thread::hardware_concurrency()); + } return 0; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp index b67119ad1..84b976439 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp +++ 
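The reference_contraction.hpp hunk above generalizes the CPU reference kernel that validates the new 6D contraction_bilinear and contraction_scale instances listed earlier: the "hardcoded for NumDimM == NumDimN == NumDimK == 2" restriction becomes an enable_if constraint allowing 2 or 6 dimensions per index group, unused K extents (K2..K5) are padded with 1, the per-element lambda takes the full (m0..m5, n0..n5) index set, and the 2D path simply passes ones for the extra ParallelTensorFunctor lengths. The standalone sketch below illustrates the same pad-with-ones idea; Extents6, volume, reference_contraction and the flattened row-major [M..., K...] / [N..., K...] / [M..., N...] layouts are assumptions made for this illustration, not the library's Tensor API.

// Illustrative sketch only: each index group (M, N, K) is treated as six
// dimensions and unused trailing dimensions carry extent 1, mirroring the
// K2..K5 = 1 padding in the patch above.
#include <array>
#include <cstddef>
#include <vector>

using Extents6 = std::array<std::size_t, 6>;

inline std::size_t volume(const Extents6& e)
{
    std::size_t v = 1;
    for(std::size_t x : e)
        v *= x;
    return v;
}

// A is stored as [M..., K...], B as [N..., K...], E as [M..., N...], with each
// group flattened in row-major order (an assumption of this sketch).
void reference_contraction(const std::vector<float>& a,
                           const std::vector<float>& b,
                           std::vector<float>& e,
                           const Extents6& m_dims, // pad unused entries with 1
                           const Extents6& n_dims,
                           const Extents6& k_dims)
{
    const std::size_t M = volume(m_dims);
    const std::size_t N = volume(n_dims);
    const std::size_t K = volume(k_dims);

    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t n = 0; n < N; ++n)
        {
            double acc = 0.0; // accumulate in wider precision, like AccDataType
            for(std::size_t k = 0; k < K; ++k)
            {
                acc += static_cast<double>(a[m * K + k]) *
                       static_cast<double>(b[n * K + k]);
            }
            e[m * N + n] = static_cast<float>(acc);
        }
    }
}

Because the index groups are flattened, a 2D problem is just the 6D problem with four extents of 1, which is exactly how the updated Run() pads K2..K5 and the functor lengths.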
b/library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include @@ -42,30 +42,31 @@ template + typename CDEElementwiseOp, + index_t NumDim = 2> using device_contraction_kk_instance = std::tuple< // clang-format off //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, 
CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 
2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 
8>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, ComputeDataType>, // Small scalar per vector - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType> + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType> // clang-format on >; @@ -78,33 +79,34 
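The remaining hunks of device_contraction_instance.hpp repeat the same mechanical change for device_contraction_kn_instance, device_contraction_mk_instance and device_contraction_mn_instance: every hardcoded 2, 2, 2 prefix becomes NumDim, NumDim, NumDim, driven by the new index_t NumDim = 2 template parameter, so the existing tile configurations can be reused unchanged by both the 2D and the new 6D instance translation units. A reduced sketch of the pattern follows; DeviceContraction, MPerBlock and NPerBlock are placeholders standing in for the real DeviceContractionMultipleD_Xdl_CShuffle template and its long tuning-parameter list.

// Reduced illustration of the NumDim parameterization (placeholder types only).
#include <cstdint>
#include <tuple>

using index_t = std::int32_t;

template <index_t NumDimM, index_t NumDimN, index_t NumDimK, int MPerBlock, int NPerBlock>
struct DeviceContraction // stand-in for DeviceContractionMultipleD_Xdl_CShuffle
{
};

template <index_t NumDim = 2> // same default the patch gives the instance aliases
using device_contraction_kk_instance_sketch =
    std::tuple<DeviceContraction<NumDim, NumDim, NumDim, 256, 128>,
               DeviceContraction<NumDim, NumDim, NumDim, 128, 256>,
               DeviceContraction<NumDim, NumDim, NumDim, 128, 128>>;

// The 2D instance files keep their previous types via the default argument ...
using instances_2d = device_contraction_kk_instance_sketch<>;
// ... while the new 6D instance files instantiate the same list with NumDim = 6.
using instances_6d = device_contraction_kk_instance_sketch<6>;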
@@ template + typename CDEElementwiseOp, + index_t NumDim = 2> using device_contraction_kn_instance = std::tuple< // clang-format off //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, 
AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 
32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 
2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, // Small scalar per vector - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, 
S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType> + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType> // clang-format on >; @@ -117,33 +119,34 @@ template + typename CDEElementwiseOp, + index_t NumDim = 2> using device_contraction_mk_instance = std::tuple< // clang-format off //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - 
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, 
BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, 
S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, // Small scalar per vector - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType> + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 4, 1, 
1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType> // clang-format on >; @@ -156,33 +159,34 @@ template + typename CDEElementwiseOp, + index_t NumDim = 2> using device_contraction_mn_instance = std::tuple< // clang-format off //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - 
DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, 
CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 
16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 
64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>, // Small scalar per vector - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType> + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 8>, 2, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType> // clang-format on >; @@ -195,23 +199,24 @@ template + typename CDEElementwiseOp, + index_t NumDim = 2> using device_contraction_f64_kk_instance = std::tuple< // clang-format off //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| 
_MBlock_MWaveMPerXdl| ScalarPerVector| Data| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, 
ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 32, 64, 16, 2, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType> + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, 
DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 32, 64, 16, 2, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1, ComputeDataType> // clang-format on >; @@ -224,23 +229,24 @@ template + typename CDEElementwiseOp, + index_t NumDim = 2> using device_contraction_f64_kn_instance = std::tuple< // clang-format off //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, 
CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 1, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 1, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType> + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 
1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 1, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 1, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType> // clang-format on >; @@ -253,23 +259,24 @@ template + typename CDEElementwiseOp, + index_t NumDim = 2> using device_contraction_f64_mk_instance = std::tuple< // clang-format off //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| 
NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, 
ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 2, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType> + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, 
EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 2, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType> // clang-format on >; @@ -282,23 +289,24 @@ template + typename CDEElementwiseOp, + index_t NumDim = 2> using device_contraction_f64_mn_instance = std::tuple< // clang-format off //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute| //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type| //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, 
BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 
1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType> + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 
1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>, + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp index d06cab119..948c7ff3a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -16,6 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { +// 2D #ifdef CK_ENABLE_FP32 void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( std::vector>>& instances); #endif // CK_ENABLE_FP16 +// 6D +#ifdef CK_ENABLE_FP32 +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance( + std::vector>>& instances); +#endif // CK_ENABLE_FP32 + +#ifdef CK_ENABLE_FP64 +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( + std::vector>>& instances); + +void 
add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance( + std::vector>>& instances); +#endif // CK_ENABLE_FP64 + +#ifdef CK_ENABLE_FP16 +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance( + std::vector>>& instances); +#endif // CK_ENABLE_FP16 + +#ifdef CK_ENABLE_BF16 +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance( + std::vector>>& instances); + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance( + std::vector>>& instances); +#endif // CK_ENABLE_FP16 // Contraction + Bilinear template ) + { + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( + op_ptrs); + } + else if constexpr(is_same_v) + { + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance( + op_ptrs); + } + else if constexpr(is_same_v) + { + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance( + op_ptrs); + } + } } #endif // CK_ENABLE_FP32 #ifdef CK_ENABLE_FP64 @@ -496,6 +905,31 @@ struct DeviceOperationInstanceFactory) + { + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( + op_ptrs); + 
add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( + op_ptrs); + } + else if constexpr(is_same_v) + { + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance( + op_ptrs); + } + } } #endif // CK_ENABLE_FP64 #ifdef CK_ENABLE_FP16 @@ -516,6 +950,20 @@ struct DeviceOperationInstanceFactory) + { + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance( + op_ptrs); + } + } } #endif // CK_ENABLE_FP16 #ifdef CK_ENABLE_BF16 @@ -536,6 +984,20 @@ struct DeviceOperationInstanceFactory) + { + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance( + op_ptrs); + add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance( + op_ptrs); + } + } } #endif // CK_ENABLE_BF16 return op_ptrs; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp index 8e994d61a..86becdb31 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
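The bilinear factory hunks above gate the new registrations on the contraction rank and on the requested compute type: only the six-mode specialization pulls in the m6_n6_k6 lists, and inside it the native f32, compute-f16 and compute-bf16 sets are selected with if constexpr. The factory's real template arguments are elided in this dump, so the following is only a minimal sketch of that dispatch idiom, with made-up Float32/Half/BFloat16 tags and registration functions standing in for the CK types and the add_device_contraction_bilinear_* calls.

// Minimal sketch of the compile-time dispatch used by the instance factory:
// pick a set of registration functions based on rank and ComputeDataType.
// All names here are illustrative stand-ins, not the CK API.
#include <cstdio>
#include <type_traits>

struct Float32 {};
struct Half {};
struct BFloat16 {};

void register_f32_instances()  { std::puts("registered native f32 instances"); }
void register_f16_instances()  { std::puts("registered f32 data / f16 compute instances"); }
void register_bf16_instances() { std::puts("registered f32 data / bf16 compute instances"); }

template <int NumDimM, int NumDimN, int NumDimK, typename ComputeDataType>
void add_instances()
{
    // Only the 6/6/6 specialization registers the new m6_n6_k6 instances;
    // other ranks fall through and register nothing here.
    if constexpr(NumDimM == 6 && NumDimN == 6 && NumDimK == 6)
    {
        if constexpr(std::is_same_v<ComputeDataType, Float32>)
            register_f32_instances();
        else if constexpr(std::is_same_v<ComputeDataType, Half>)
            register_f16_instances();
        else if constexpr(std::is_same_v<ComputeDataType, BFloat16>)
            register_bf16_instances();
    }
}

int main()
{
    add_instances<6, 6, 6, Half>();     // registers the f16-compute set
    add_instances<2, 2, 2, Float32>();  // no 6D registration for two-mode problems
}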
#pragma once @@ -16,6 +16,7 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { +// 2D #ifdef CK_ENABLE_FP32 void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( std::vector>>& instances); #endif // CK_ENABLE_FP16 +// 6D +#ifdef CK_ENABLE_FP32 +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_kkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mnn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance( + std::vector>>& instances); +#endif // CK_ENABLE_FP32 + +#ifdef CK_ENABLE_FP64 +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_kkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mnn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance( + std::vector>>& instances); +#endif // CK_ENABLE_FP64 + +#ifdef CK_ENABLE_FP16 +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance( + std::vector>>& instances); +#endif // CK_ENABLE_FP16 + +#ifdef CK_ENABLE_BF16 +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance( + std::vector>>& instances); + +void 
add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance( + std::vector>>& instances); + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance( + std::vector>>& instances); +#endif // CK_ENABLE_FP16 // Contraction + Scale template ) + { + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_kkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_knn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mnn_instance( + op_ptrs); + } + else if constexpr(is_same_v) + { + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance( + op_ptrs); + } + else if constexpr(is_same_v) + { + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance( + op_ptrs); + } + } } #endif // CK_ENABLE_FP32 #ifdef CK_ENABLE_FP64 @@ -495,6 +904,31 @@ struct DeviceOperationInstanceFactory) + { + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_kkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_knn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mnn_instance( + op_ptrs); + } + else if constexpr(is_same_v) + { + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance( + op_ptrs); + } + } } #endif // CK_ENABLE_FP64 #ifdef CK_ENABLE_FP16 @@ -515,6 +949,20 @@ struct DeviceOperationInstanceFactory) + { + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance( + op_ptrs); + } + } } #endif // CK_ENABLE_FP16 #ifdef CK_ENABLE_BF16 @@ -535,6 +983,20 @@ struct DeviceOperationInstanceFactory) + { + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance( + op_ptrs); + 
add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance( + op_ptrs); + add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance( + op_ptrs); + } + } } #endif // CK_ENABLE_BF16 return op_ptrs; diff --git a/library/include/ck/library/utility/host_tensor.hpp b/library/include/ck/library/utility/host_tensor.hpp index 816d83413..ddbd16ad9 100644 --- a/library/include/ck/library/utility/host_tensor.hpp +++ b/library/include/ck/library/utility/host_tensor.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -408,6 +408,37 @@ struct Tensor mDesc.GetLengths()[5])(num_thread); break; } + case 12: { + auto f = [&](auto i0, + auto i1, + auto i2, + auto i3, + auto i4, + auto i5, + auto i6, + auto i7, + auto i8, + auto i9, + auto i10, + auto i11) { + (*this)(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11) = + g(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11); + }; + make_ParallelTensorFunctor(f, + mDesc.GetLengths()[0], + mDesc.GetLengths()[1], + mDesc.GetLengths()[2], + mDesc.GetLengths()[3], + mDesc.GetLengths()[4], + mDesc.GetLengths()[5], + mDesc.GetLengths()[6], + mDesc.GetLengths()[7], + mDesc.GetLengths()[8], + mDesc.GetLengths()[9], + mDesc.GetLengths()[10], + mDesc.GetLengths()[11])(num_thread); + break; + } default: throw std::runtime_error("unspported dimension"); } } diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp similarity index 94% rename from library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp rename to library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp index 56dc1d2c7..fbfaaa447 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/2D/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! 
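The host_tensor.hpp hunk above extends the rank switch with a case 12 branch, so host tensors used to fill and verify the new contractions can be generated elementwise when they carry twelve modes (six M plus six K, or six M plus six N, if the m6_n6_k6 naming is read that way). The library unrolls one lambda arity per rank and hands it to make_ParallelTensorFunctor; the recursive helper below is only a rank-generic illustration of the same visit-every-index-tuple behaviour, not the CK code.

// Sketch of what the added "case 12" does: visit every index tuple of a
// rank-N host tensor and assign g(i0, ..., i11).  The library unrolls this
// per rank; the recursive helper below is an equivalent, generic illustration.
#include <array>
#include <cstdio>
#include <vector>

template <typename F, std::size_t Rank>
void for_each_index(const std::array<std::size_t, Rank>& lengths,
                    std::array<std::size_t, Rank>& idx,
                    std::size_t dim,
                    F&& f)
{
    if(dim == Rank)
    {
        f(idx);
        return;
    }
    for(std::size_t i = 0; i < lengths[dim]; ++i)
    {
        idx[dim] = i;
        for_each_index(lengths, idx, dim + 1, f);
    }
}

int main()
{
    // A rank-12 "tensor" of lengths 2 x 1 x ... x 1, stored packed in a vector.
    std::array<std::size_t, 12> lengths{2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    std::vector<float> data(2, 0.f);

    std::array<std::size_t, 12> idx{};
    for_each_index(lengths, idx, 0, [&](const auto& i) {
        // generator g: here just the sum of all twelve indices
        std::size_t s = 0;
        for(std::size_t d = 0; d < 12; ++d) s += i[d];
        data[i[0]] = static_cast<float>(s);
    });

    std::printf("%f %f\n", data[0], data[1]);
}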
@@ -31,7 +31,8 @@ using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_com F32, PassThrough, PassThrough, - Bilinear>; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( std::vector; + Bilinear, + 2>; void 
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( std::vector; + Bilinear, + 2>; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( std::vector + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance = + device_contraction_kk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp new file mode 100644 index 000000000..a2da04f70 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
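Besides moving the existing 2D instance files into a 2D/ subdirectory, each of their device_contraction_*_instance aliases gains a trailing 2, presumably the number of M/N/K modes, so the shared definitions in device_contraction_instance.hpp can also be stamped out with 6 for the new files. The exact parameter list is not visible in this dump; the sketch below only shows the shape of such a dimension-parameterized alias, with invented names.

// Illustrative only: one family of instance lists parameterized by how many
// M/N/K modes the contraction has, so "..., 2>" and a 6D counterpart reuse a
// single definition.  Op is a stand-in for the real device-operation template.
#include <cstdio>
#include <tuple>

template <int NumDim>
struct Op
{
    static void describe() { std::printf("contraction with %d M/N/K modes\n", NumDim); }
};

template <int NumDim>
using device_contraction_kk_instances = std::tuple<Op<NumDim>, Op<NumDim>>; // real lists hold many tunings

using instances_2d = device_contraction_kk_instances<2>; // what the added ", 2>" selects
using instances_6d = device_contraction_kk_instances<6>; // what the new 6D files would select

int main()
{
    std::tuple_element_t<0, instances_2d>::describe();
    std::tuple_element_t<0, instances_6d>::describe();
}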
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance = + device_contraction_kn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp new file mode 100644 index 000000000..7f076c985 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance = + device_contraction_mk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp new file mode 100644 index 000000000..c9e003fea --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp new file mode 100644 index 000000000..655715118 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
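Each instance file records which dimension is fastest changing for A/B/D/E (k/k/n/n in the kknn files, m/n/n/n in the mnnn files, and so on), and the kk/kn/mk/mn aliases follow the same letters for A and B. Fastest changing simply means stride 1 in a packed layout; the sketch below computes packed strides for an A tensor with six M modes and six K modes, a rank assumed from the new rank-12 host-tensor case, with either block placed innermost.

// Sketch: packed strides for A with 6 M modes and 6 K modes.  "k fastest
// changing" (the kk*/kn* instances) keeps the K block innermost, so its last
// K mode has stride 1; "m fastest changing" (mk*/mn*) keeps the M block
// innermost instead.  The lengths are arbitrary examples.
#include <cstdio>
#include <vector>

std::vector<long> packed_strides(const std::vector<long>& lengths)
{
    std::vector<long> strides(lengths.size());
    long s = 1;
    for(int d = static_cast<int>(lengths.size()) - 1; d >= 0; --d)
    {
        strides[d] = s;
        s *= lengths[d];
    }
    return strides;
}

int main()
{
    std::vector<long> m_modes{2, 2, 2, 2, 2, 2}; // m0..m5
    std::vector<long> k_modes{3, 3, 3, 3, 3, 3}; // k0..k5

    // k fastest changing: order A as [m0..m5, k0..k5]
    std::vector<long> a_k_fast = m_modes;
    a_k_fast.insert(a_k_fast.end(), k_modes.begin(), k_modes.end());

    // m fastest changing: order A as [k0..k5, m0..m5]
    std::vector<long> a_m_fast = k_modes;
    a_m_fast.insert(a_m_fast.end(), m_modes.begin(), m_modes.end());

    std::printf("k-fast: stride(k5) = %ld, stride(m5) = %ld\n",
                packed_strides(a_k_fast)[11], packed_strides(a_k_fast)[5]);
    std::printf("m-fast: stride(m5) = %ld, stride(k5) = %ld\n",
                packed_strides(a_m_fast)[11], packed_strides(a_m_fast)[5]);
}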
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance = + device_contraction_kk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp new file mode 100644 index 000000000..4904360ea --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance = + device_contraction_kn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp new file mode 100644 index 000000000..069a8a744 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance = + device_contraction_mk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp new file mode 100644 index 000000000..605237138 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp new file mode 100644 index 000000000..16540cf11 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
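Every one of these new files follows the same pattern: define an instance list, then an add_device_contraction_*_instance function that hands a default-constructed list to add_device_operation_instances, which appends one operation pointer per entry to the caller's vector (whose element type is elided in this dump). The self-contained stand-in below shows that mechanism in miniature; it is not the CK implementation.

// Toy stand-in for the registration pattern: each instance type in a tuple is
// default-constructed and pushed into a vector of base-class pointers.
#include <cstdio>
#include <memory>
#include <tuple>
#include <vector>

struct BaseOperator
{
    virtual ~BaseOperator() = default;
    virtual const char* GetTypeString() const = 0;
};

template <int BlockSize>
struct ContractionInstance : BaseOperator
{
    const char* GetTypeString() const override
    {
        return BlockSize == 256 ? "instance<256>" : "instance<other>";
    }
};

// Simplified analogue of add_device_operation_instances(instances, Tuple{}).
template <typename... Ts>
void add_device_operation_instances(std::vector<std::unique_ptr<BaseOperator>>& instances,
                                    std::tuple<Ts...>)
{
    (instances.push_back(std::make_unique<Ts>()), ...);
}

using my_instance_list = std::tuple<ContractionInstance<256>, ContractionInstance<128>>;

void add_my_instances(std::vector<std::unique_ptr<BaseOperator>>& instances)
{
    add_device_operation_instances(instances, my_instance_list{});
}

int main()
{
    std::vector<std::unique_ptr<BaseOperator>> ops;
    add_my_instances(ops);
    for(const auto& op : ops)
        std::printf("%s\n", op->GetTypeString());
}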
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance = + device_contraction_kk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp new file mode 100644 index 000000000..2f1206bbf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance = + device_contraction_kn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp new file mode 100644 index 000000000..9bd03b37a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance = + device_contraction_mk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp new file mode 100644 index 000000000..55a1fb301 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp new file mode 100644 index 000000000..5394dc80d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
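The files in this stretch keep A, B, D and E in f32 but register compute_f16 and compute_bf16 variants, where the ComputeDataType argument routes the inner math through the corresponding lower-precision path. The host-side sketch below illustrates only that precision contract, truncating floats to bf16-width values before the products; it is not the device kernel, and the f32 accumulation shown is this example's choice rather than a statement about the instances.

// Host-side illustration of "f32 data, bf16 compute": inputs and outputs stay
// float, but each operand is truncated to bf16 precision (upper 16 bits of the
// f32 bit pattern) before it enters the multiply-accumulate, mimicking what a
// reduced-precision ComputeDataType selects.
#include <cstdint>
#include <cstdio>
#include <cstring>

float to_bf16_and_back(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    bits &= 0xffff0000u; // keep sign, exponent and the top 7 mantissa bits
    std::memcpy(&x, &bits, sizeof(bits));
    return x;
}

float dot_f32_compute_bf16(const float* a, const float* b, int n)
{
    float acc = 0.f; // accumulate in f32 for this illustration
    for(int i = 0; i < n; ++i)
        acc += to_bf16_and_back(a[i]) * to_bf16_and_back(b[i]);
    return acc;
}

int main()
{
    float a[4] = {1.001f, 2.002f, 3.003f, 4.004f};
    float b[4] = {0.5f, 0.25f, 0.125f, 0.0625f};

    float exact = 0.f;
    for(int i = 0; i < 4; ++i) exact += a[i] * b[i];

    std::printf("f32 compute: %.6f  bf16 compute: %.6f\n",
                exact, dot_f32_compute_bf16(a, b, 4));
}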
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance = + device_contraction_kk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp new file mode 100644 index 000000000..f5967e75e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance = + device_contraction_kn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp new file mode 100644 index 000000000..10713d741 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance = + device_contraction_mk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp new file mode 100644 index 000000000..1dfcfaae1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp new file mode 100644 index 000000000..112b5e1a0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance = + device_contraction_kk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp new file mode 100644 index 000000000..f281a9fc7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
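Apart from the epilogue, the contraction_bilinear and contraction_scale instance lists touched in this patch mirror each other: Bilinear combines the contraction result with the extra D tensor, while Scale has no D operand and only applies a factor, which is also why the scale names carry three layout letters (kkn, mnn, ...) instead of four. The functors below are hedged stand-ins for those element-wise operators, written out only to show the difference in the epilogue math.

// Stand-in epilogue functors showing what distinguishes the *_bilinear and
// *_scale instance lists.  These mirror, but are not, the CK element-wise ops.
#include <cstdio>

struct Bilinear
{
    float alpha;
    float beta;
    void operator()(float& e, float contraction_result, float d) const
    {
        e = alpha * contraction_result + beta * d; // E = alpha*(A x B) + beta*D
    }
};

struct Scale
{
    float alpha;
    void operator()(float& e, float contraction_result) const
    {
        e = alpha * contraction_result; // E = alpha*(A x B)
    }
};

int main()
{
    float e0 = 0.f, e1 = 0.f;
    Bilinear{1.5f, 0.5f}(e0, /*A x B =*/4.f, /*D =*/2.f);
    Scale{1.5f}(e1, /*A x B =*/4.f);
    std::printf("bilinear: %.1f  scale: %.1f\n", e0, e1);
}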
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance = + device_contraction_kn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp new file mode 100644 index 000000000..9c79c15a3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance = + device_contraction_mk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp new file mode 100644 index 000000000..1a4425e45 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp new file mode 100644 index 000000000..52c8fb1c1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance = + device_contraction_f64_kk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp new file mode 100644 index 000000000..d023e438d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance = + device_contraction_f64_kn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp new file mode 100644 index 000000000..72bb9eda3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance = + device_contraction_f64_mk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp new file mode 100644 index 000000000..97ef213be --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance = + device_contraction_f64_mn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp new file mode 100644 index 000000000..06b1ce8fc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
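The f64_f64_f64_f64_compute_f32 variants above keep A, B, D and E in double precision but select F32 as the compute type, trading some accuracy for higher throughput. A minimal scalar illustration of that idea (an assumption-laden sketch, not CK code):

#include <cstddef>
#include <vector>

// Hedged sketch of the "compute_f32" idea: storage stays in double, but each
// multiply-accumulate of the inner product is carried out in float.
double dot_f64_compute_f32(const std::vector<double>& a, const std::vector<double>& b)
{
    float acc = 0.0f;
    for(std::size_t i = 0; i < a.size() && i < b.size(); ++i)
    {
        acc += static_cast<float>(a[i]) * static_cast<float>(b[i]);
    }
    return static_cast<double>(acc);
}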
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance = + device_contraction_f64_kk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp new file mode 100644 index 000000000..078facb33 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance = + device_contraction_f64_kn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp new file mode 100644 index 000000000..14eae08fd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance = + device_contraction_f64_mk_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp new file mode 100644 index 000000000..982e4dd5f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/6D/device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance = + device_contraction_f64_mn_instance; + +void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt index a28c6717d..70e4bbfe5 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -1,43 +1,49 @@ # ONLY XDL_KERNELS set(DEVICE_CONTRACTION_BILINEAR_INSTANCES) -# FP32 -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp) - -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp) - -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp) - -# FP64 -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp) - -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES 
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp) - -# FP16 -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp) - -# BF16 -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp) +list(APPEND DIMS 2 6) + +foreach(idx IN LISTS DIMS) + set(PREFIX ${idx}D/device_contraction_bilinear_m${idx}_n${idx}_k${idx}) + + # FP32 + list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp) + + list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp) + + list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp) + + # FP64 + list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp) + + list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp) + + # FP16 + list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES ${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp) + + # BF16 + list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES 
${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp + ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp + ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp) +endforeach() add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp similarity index 94% rename from library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp rename to library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp index ace4d4a33..eb8f3641b 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/2D/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // This (ifndef) is a hack to use customized behavior for buffer load rather than using default // setting Don't use this hack unless absolutely necessary! 
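The contraction_scale files that follow differ from the bilinear ones only in their CDE element-wise epilogue: Scale multiplies the A x B accumulator by a single factor and takes no extra D tensor, while Bilinear blends the accumulator with D. The functors below are a simplified sketch of that distinction (member names and signatures are illustrative assumptions, not the definitions from element_wise_operation.hpp):

// Hedged sketch of the two epilogue operations used by these instance files.
struct BilinearSketch
{
    float alpha;
    float beta;
    // e = alpha * c + beta * d, where c is the A x B accumulator and d the extra D input.
    void operator()(float& e, const float& c, const float& d) const { e = alpha * c + beta * d; }
};

struct ScaleSketch
{
    float alpha;
    // e = alpha * c; the scale variants have no D tensor, which is why their
    // layout tags have three letters (kkn, knn, mkn, mnn) instead of four.
    void operator()(float& e, const float& c) const { e = alpha * c; }
};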
@@ -31,7 +31,8 @@ using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32 F32, PassThrough, PassThrough, - Scale>; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( std::vector; + Scale, + 2>; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( std::vector + +#include "ck/ck.hpp" +#include 
"ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance = + device_contraction_kk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp new file mode 100644 index 000000000..a7e5aad3b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance = + device_contraction_kn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp new file mode 100644 index 000000000..0f80e2e5d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance = + device_contraction_mk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp new file mode 100644 index 000000000..e335b2062 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp new file mode 100644 index 000000000..729fcf977 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance = + device_contraction_kk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp new file mode 100644 index 000000000..1c0b0084c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance = + device_contraction_kn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp new file mode 100644 index 000000000..4cae7d5f7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance = + device_contraction_mk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp new file mode 100644 index 000000000..263442e39 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp new file mode 100644 index 000000000..24965e16f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance = + device_contraction_kk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp new file mode 100644 index 000000000..3859cfebc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance = + device_contraction_kn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp new file mode 100644 index 000000000..f7c6f77c5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance = + device_contraction_mk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp new file mode 100644 index 000000000..1eb71160c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp new file mode 100644 index 000000000..5a32dd09a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance = + device_contraction_kk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp new file mode 100644 index 000000000..e773803f9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance = + device_contraction_kn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp new file mode 100644 index 000000000..c30460d0a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance = + device_contraction_mk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp new file mode 100644 index 000000000..f5cde5ac2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp new file mode 100644 index 000000000..d21b75645 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_kkn_instance = + device_contraction_kk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp new file mode 100644 index 000000000..d2ef6eb83 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_knn_instance = + device_contraction_kn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp new file mode 100644 index 000000000..718c81c75 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mkn_instance = + device_contraction_mk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp new file mode 100644 index 000000000..6ef8dae23 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mnn_instance = + device_contraction_mn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f32_f32_f32_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp new file mode 100644 index 000000000..ff7d2ddb9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance = + device_contraction_f64_kk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp new file mode 100644 index 000000000..19ae89250 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance = + device_contraction_f64_kn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp new file mode 100644 index 000000000..4af5853c2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance = + device_contraction_f64_mk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp new file mode 100644 index 000000000..65f0c84a9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance = + device_contraction_f64_mn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp new file mode 100644 index 000000000..6bf5c3497 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_kkn_instance = + device_contraction_f64_kk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_kkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_kkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp new file mode 100644 index 000000000..4324914df --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// k/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_knn_instance = + device_contraction_f64_kn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_knn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_knn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp new file mode 100644 index 000000000..0ae7215e5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/k/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mkn_instance = + device_contraction_f64_mk_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mkn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mkn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp new file mode 100644 index 000000000..2a56a88fb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/6D/device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// setting Don't use this hack unless absolutely necessary! 
+// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] +// m/n/n/n are the fast changing dimension for A/B/D/E +using device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mnn_instance = + device_contraction_f64_mn_instance; + +void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mnn_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_mnn_instance{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt index b91de832e..dd36f88c4 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ -1,44 +1,50 @@ # ONLY XDL_KERNELS set(DEVICE_CONTRACTION_SCALE_INSTANCES) -# FP32 -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp) - -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp) - -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp) - -# FP64 -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp) - -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp - 
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp) - -# FP16 -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp) - -# BF16 -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp) +list(APPEND DIMS 2 6) + +foreach(idx IN LISTS DIMS) + set(PREFIX ${idx}D/device_contraction_scale_m${idx}_n${idx}_k${idx}) + + # FP32 + list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES ${PREFIX}_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp) + + list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES ${PREFIX}_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp) + + list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES ${PREFIX}_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp) + + # FP64 + list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES ${PREFIX}_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp) + + list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES ${PREFIX}_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp) + + # FP16 + list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES ${PREFIX}_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp) + + # BF16 + list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp + ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp + ${PREFIX}_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp) +endforeach() add_instance_library(device_contraction_scale_instance ${DEVICE_CONTRACTION_SCALE_INSTANCES}) diff --git a/profiler/README.md b/profiler/README.md index f26c90d0b..a4daefba9 100644 --- 
a/profiler/README.md +++ b/profiler/README.md @@ -52,21 +52,23 @@ Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s #arg1: tensor operation (contraction_bilinear=CONTRACTION+Bilinear) #arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16) #arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16) -#arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; +#arg4: Number of dimension for M, N and K (one for all) +#arg5: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]) -#arg5: verification (0: no; 1: yes) -#arg6: initialization (0: no init; 1: integer value; 2: decimal value) -#arg7: print tensor value (0: no; 1: yes) -#arg8: time kernel (0: no, 1: yes) -#arg9: alpha -#arg10: beta -#arg11 to 16: M0, M1, N0, N1, K0, K1 -#arg17 to 32: Strides for A, B, D and E (skip for default) +#arg6: verification (0: no; 1: yes) +#arg7: initialization (0: no init; 1: integer value; 2: decimal +# value) +#arg8: print tensor value (0: no; 1: yes) +#arg9: time kernel (0: no, 1: yes) +#arg10: alpha +#arg11: beta +#arg12 to 17/29: M0, M1, N0, N1, K0, K1 +#arg18/30 to 33/77: Strides for A, B, D and E (skip for default) -################ op datatype compute_datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1 -./bin/ckProfiler contraction_bilinear 0 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128 +################ op datatype compute_datatype num_dim layout verify init log time alpha beta M0 M1 N0 N1 K0 K1 +./bin/ckProfiler contraction_bilinear 0 0 2 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128 ``` Result (MI100) diff --git a/profiler/include/profiler/profile_contraction_impl.hpp b/profiler/include/profiler/profile_contraction_impl.hpp index f6e4b3f39..604032a01 100644 --- a/profiler/include/profiler/profile_contraction_impl.hpp +++ b/profiler/include/profiler/profile_contraction_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
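// The README hunk above renumbers the ckProfiler arguments around the new "number of
// dimensions" parameter (arg4). A worked layout for contraction_bilinear, derived only
// from the argument indices in that hunk; the sample sizes below are illustrative and
// mirror the 6D test shapes added later in this patch, not a verified profiler run:
//
//   num_dim = 2: args 12-17 hold M0 M1 N0 N1 K0 K1; optional strides occupy args 18-33
//                (4 tensors x 4 strides), so argc is 18 without strides, 34 with them.
//   num_dim = 6: args 12-29 hold the six M extents, six N extents and six K extents;
//                optional strides occupy args 30-77 (4 tensors x 12 strides), so argc
//                is 30 or 78.
//
//   # hypothetical 6D run with default strides:
//   ./bin/ckProfiler contraction_bilinear 0 0 6 1 1 1 0 1 1.0 1.0 \
//       2 3 2 3 2 3   2 3 2 3 2 3   2 2 2 2 2 4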
#pragma once @@ -22,6 +22,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" +#include "ck/library/utility/numeric.hpp" #include "ck/host_utility/io.hpp" @@ -34,7 +35,8 @@ using Scale = ck::tensor_operation::element_wise::Scale; using F32 = float; using F64 = double; -template a_ms_ks_lengths = {M[0], M[1], K[0], K[1]}; - const std::vector b_ns_ks_lengths = {N[0], N[1], K[0], K[1]}; - const std::vector e_ms_ns_lengths = {M[0], M[1], N[0], N[1]}; - const std::vector d_m_n_lengths = {M[0], M[1], N[0], N[1]}; + auto merge_dims = [](const std::vector& dims01, + const std::vector& dims23) { + std::vector dims_szt(dims01.begin(), dims01.end()); + dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end()); + return dims_szt; + }; + + const std::vector a_ms_ks_lengths = merge_dims(M, K); + const std::vector b_ns_ks_lengths = merge_dims(N, K); + const std::vector e_ms_ns_lengths = merge_dims(M, N); + const std::vector d_m_n_lengths = merge_dims(M, N); const auto a_element_op = AElementOp{}; const auto b_element_op = BElementOp{}; - constexpr ck::index_t NumDim = 2; - using DeviceOp = ck::tensor_operation::device::DeviceContractionMultipleD::value) { - for(size_t n0 = 0; n0 < e_m_n_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_m_n_host_result.mDesc.GetLengths()[3]; ++n1) - { - if constexpr(is_same::value) - { - cde_element_op(e_m_n_host_result(m0, m1, n0, n1), - c_m_n_host_result(m0, m1, n0, n1), - d_m_n(m0, m1, n0, n1)); - } - else if constexpr(is_same::value) - { - cde_element_op(e_m_n_host_result(m0, m1, n0, n1), - c_m_n_host_result(m0, m1, n0, n1)); - } - else - { - static_assert("Unsupported CDElementOp in contraction profiler."); - } - } - } + cde_element_op(self(idx), c_m_n_host_result(idx), d_m_n(idx)); } - } + else if constexpr(is_same::value) + { + cde_element_op(self(idx), c_m_n_host_result(idx)); + } + else + { + static_assert("Unsupported CDElementOp in contraction profiler."); + } + }); } std::string best_op_name; @@ -242,9 +237,12 @@ int profile_contraction_impl(ck::index_t do_verification, auto invoker_ptr = op_ptr->MakeInvokerPointer(); - auto nelems_m = M[0] * M[1]; - auto nelems_n = N[0] * N[1]; - auto nelems_k = K[0] * K[1]; + auto nelems_m = ck::accumulate_n( + a_ms_ks_lengths.begin(), NumDimMNK, 1, std::multiplies<>{}); + auto nelems_n = ck::accumulate_n( + b_ns_ks_lengths.begin(), NumDimMNK, 1, std::multiplies<>{}); + auto nelems_k = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimMNK, NumDimMNK, 1, std::multiplies<>{}); if(op_ptr->IsSupportedArgument(argument_ptr.get())) { diff --git a/profiler/include/profiler/profile_contraction_utils.hpp b/profiler/include/profiler/profile_contraction_utils.hpp index 05ec7daf9..adfd98a37 100644 --- a/profiler/include/profiler/profile_contraction_utils.hpp +++ b/profiler/include/profiler/profile_contraction_utils.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
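// A self-contained sketch of the merge_dims helper added in the hunk above. It only
// concatenates the two per-tensor extent vectors (for example M followed by K for the
// A tensor), replacing the old hard-coded {M[0], M[1], K[0], K[1]} initializers. The
// element types and the standalone main below are illustrative assumptions, not part
// of the patch.

#include <cstddef>
#include <cstdio>
#include <vector>

static std::vector<std::size_t> merge_dims_sketch(const std::vector<int>& dims01,
                                                  const std::vector<int>& dims23)
{
    // Copy the first vector, then append the second one - the same shape as the
    // lambda in profile_contraction_impl.hpp.
    std::vector<std::size_t> merged(dims01.begin(), dims01.end());
    merged.insert(merged.end(), dims23.begin(), dims23.end());
    return merged;
}

int main()
{
    // The 2D test case added later in this patch uses M = N = K = {16, 8}, so the
    // merged A lengths (M followed by K) come out as {16, 8, 16, 8}.
    for(std::size_t d : merge_dims_sketch({16, 8}, {16, 8}))
    {
        std::printf("%zu ", d); // prints: 16 8 16 8
    }
    std::printf("\n");
    return 0;
}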
#pragma once @@ -48,14 +48,36 @@ inline void collect_index_params(char* argv[], // Defualt strides for row-major: {Dim1 * Dim2 * Dim3, Dim2 * Dim3, Dim3, 1} // Defualt strides for column-major: {Dim1, 1, Dim0 * Dim1 * Dim3, Dim0 * Dim1} + +// M1, 1, M0 * M1 * K1, M0 * M1 +// K0, K1, M0, M1 inline void assign_default_strides(Row, std::vector& strides, std::vector dims) { - strides = {dims[1] * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1}; + ck::index_t stride = 1; + for(ck::index_t s = strides.size() - 1; s >= 0; s--) + { + strides[s] = stride; + stride *= dims[s]; + } } inline void assign_default_strides(Col, std::vector& strides, std::vector dims) { - strides = {dims[1], 1, dims[0] * dims[1] * dims[3], dims[0] * dims[1]}; + // Assign second half of strides + ck::index_t stride = 1; + for(ck::index_t s = strides.size() / 2 - 1; s >= 0; s--) + { + strides[s] = stride; + stride *= dims[s]; + } + + // Assign first half of strides + for(ck::index_t s = strides.size() - 1; s > static_cast(strides.size()) / 2 - 1; + s--) + { + strides[s] = stride; + stride *= dims[s]; + } } diff --git a/profiler/src/profile_contraction_bilinear.cpp b/profiler/src/profile_contraction_bilinear.cpp index 8cb34ccc6..990e1e119 100644 --- a/profiler/src/profile_contraction_bilinear.cpp +++ b/profiler/src/profile_contraction_bilinear.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -19,7 +19,8 @@ static void print_helper_msg() std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" << "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" << "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" - << "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " + << "arg4: Number of dimension for M, N and K (one for all)\n" + << "arg5: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" @@ -27,23 +28,23 @@ static void print_helper_msg() "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" - << "arg5: verification (0: no; 1: yes)\n" - << "arg6: initialization (0: no init; 1: integer value; 2: decimal " + << "arg6: verification (0: no; 1: yes)\n" + << "arg7: initialization (0: no init; 1: integer value; 2: decimal " << "value)\n" - << "arg7: print tensor value (0: no; 1: yes)\n" - << "arg8: time kernel (0: no, 1: yes)\n" - << "arg9: alpha\n" - << "arg10: beta\n" - << "arg11 to 16: M0, M1, N0, N1, K0, K1\n" - << "arg17 to 32: Strides for A, B, D and E (skip for default)\n" + << "arg8: print tensor value (0: no; 1: yes)\n" + << "arg9: time kernel (0: no, 1: yes)\n" + << "arg10: alpha\n" + << "arg11: beta\n" + << "arg12 to 17/29: M0, M1, N0, N1, K0, K1\n" + << "arg18/30 to 33/77: Strides for A, B, D and E (skip for default)\n" << std::endl; } int profile_contraction_bilinear(int argc, char* argv[]) { - const bool default_strides = argc == 17; + const bool default_strides = argc == 18 || 30; - if(argc != 33 && argc != 17) + if(argc != 34 && argc != 78 && !default_strides) { print_helper_msg(); exit(1); @@ -51,32 +52,33 @@ int profile_contraction_bilinear(int argc, char* argv[]) const auto data_type = static_cast(std::stoi(argv[2])); const auto compute_data_type = static_cast(std::stoi(argv[3])); - 
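// A worked example for the generalized assign_default_strides helpers above, assuming
// packed tensors with extents {dim0, dim1, dim2, dim3} = {4, 3, 2, 5} (values chosen
// only for illustration):
//
//   row-major:    walking the dims from last to first gives
//                 {3 * 2 * 5, 2 * 5, 5, 1} = {30, 10, 5, 1}.
//   column-major: one loop fills strides[0..1] = {3, 1}, the other continues into
//                 strides[2..3] = {4 * 3 * 5, 4 * 3} = {60, 12}, reproducing the old
//                 hard-coded {dims[1], 1, dims[0] * dims[1] * dims[3], dims[0] * dims[1]}.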
const auto layout = static_cast(std::stoi(argv[4])); - const bool do_verification = std::stoi(argv[5]); - const ck::index_t init_method = std::stoi(argv[6]); - const bool do_log = std::stoi(argv[7]); - const bool time_kernel = std::stoi(argv[8]); - const float alpha = std::stof(argv[9]); - const float beta = std::stof(argv[10]); + const ck::index_t NumDimMNK = std::stoi(argv[4]); + const auto layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const ck::index_t init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + const float alpha = std::stof(argv[10]); + const float beta = std::stof(argv[11]); std::vector M; std::vector N; std::vector K; - const ck::index_t dims_arg_num = 11; - collect_index_params(argv, M, dims_arg_num, 2); - collect_index_params(argv, N, dims_arg_num + 2, 2); - collect_index_params(argv, K, dims_arg_num + 4, 2); - - std::vector StridesA; - std::vector StridesB; - std::vector StridesE; - std::vector StridesD; + const ck::index_t dims_arg_num = 12; + collect_index_params(argv, M, dims_arg_num, NumDimMNK); + collect_index_params(argv, N, dims_arg_num + NumDimMNK, NumDimMNK); + collect_index_params(argv, K, dims_arg_num + NumDimMNK * 2, NumDimMNK); + + std::vector StridesA(NumDimMNK * 2); + std::vector StridesB(NumDimMNK * 2); + std::vector StridesE(NumDimMNK * 2); + std::vector StridesD(NumDimMNK * 2); if(!default_strides) { - collect_index_params(argv, StridesA, dims_arg_num + 6, 4); - collect_index_params(argv, StridesB, dims_arg_num + 10, 4); - collect_index_params(argv, StridesE, dims_arg_num + 14, 4); - collect_index_params(argv, StridesD, dims_arg_num + 18, 4); + collect_index_params(argv, StridesA, dims_arg_num + NumDimMNK * 3, NumDimMNK * 2); + collect_index_params(argv, StridesB, dims_arg_num + NumDimMNK * 5, NumDimMNK * 2); + collect_index_params(argv, StridesE, dims_arg_num + NumDimMNK * 7, NumDimMNK * 2); + collect_index_params(argv, StridesD, dims_arg_num + NumDimMNK * 9, NumDimMNK * 2); } using F16 = ck::half_t; @@ -95,31 +97,71 @@ int profile_contraction_bilinear(int argc, char* argv[]) if(default_strides) { - assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); - assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]}); - assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]}); - assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]}); + auto merge_dims = [](const std::vector& dims01, + const std::vector& dims23) { + std::vector dims_szt(dims01.begin(), dims01.end()); + dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end()); + return dims_szt; + }; + + assign_default_strides(a_layout, StridesA, merge_dims(M, K)); + assign_default_strides(b_layout, StridesB, merge_dims(N, K)); + assign_default_strides(cde_layout, StridesE, merge_dims(M, N)); + assign_default_strides(cde_layout, StridesD, merge_dims(M, N)); + } + if(NumDimMNK == 2) + { + bool pass = ck::profiler::profile_contraction_impl<2, + ALayout, + BLayout, + CDELayout, + DataType, + ComputeDataType, + ck::Tuple, + Bilinear>(do_verification, + init_method, + do_log, + time_kernel, + Bilinear{alpha, beta}, + M, + N, + K, + StridesA, + StridesB, + StridesE, + StridesD); + + return pass; + } + else if(NumDimMNK == 6) + { + bool pass = ck::profiler::profile_contraction_impl<6, + ALayout, + BLayout, + CDELayout, + DataType, + ComputeDataType, + ck::Tuple, + Bilinear>(do_verification, + init_method, + do_log, + time_kernel, + 
Bilinear{alpha, beta}, + M, + N, + K, + StridesA, + StridesB, + StridesE, + StridesD); + + return pass; + } + else + { + throw std::runtime_error("Not supported NumDimMNK"); + return false; } - bool pass = ck::profiler::profile_contraction_impl, - Bilinear>(do_verification, - init_method, - do_log, - time_kernel, - Bilinear{alpha, beta}, - M, - N, - K, - StridesA, - StridesB, - StridesE, - StridesD); - - return pass; }; auto run_profile_for_datatype = [&](auto type, auto compute_type) { diff --git a/profiler/src/profile_contraction_scale.cpp b/profiler/src/profile_contraction_scale.cpp index ca9c19998..85252eaa3 100644 --- a/profiler/src/profile_contraction_scale.cpp +++ b/profiler/src/profile_contraction_scale.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -19,7 +19,8 @@ static void print_helper_msg() std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" << "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" << "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" - << "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " + << "arg4: Number of dimension for M, N and K (one for all)\n" + << "arg5: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" @@ -27,22 +28,22 @@ static void print_helper_msg() "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" - << "arg5: verification (0: no; 1: yes)\n" - << "arg6: initialization (0: no init; 1: integer value; 2: decimal " + << "arg6: verification (0: no; 1: yes)\n" + << "arg7: initialization (0: no init; 1: integer value; 2: decimal " << "value)\n" - << "arg7: print tensor value (0: no; 1: yes)\n" - << "arg8: time kernel (0: no, 1: yes)\n" - << "arg9: alpha\n" - << "arg10 to 15: M0, M1, N0, N1, K0, K1\n" - << "arg16 to 31: Strides for A, B, D and E (skip for default)\n" + << "arg8: print tensor value (0: no; 1: yes)\n" + << "arg9: time kernel (0: no, 1: yes)\n" + << "arg10: alpha\n" + << "arg11 to 16/28: M0, M1, N0, N1, K0, K1\n" + << "arg17/29 to 32/63: Strides for A, B, E (skip for default)\n" << std::endl; } int profile_contraction_scale(int argc, char* argv[]) { - const bool default_strides = argc == 16; + const bool default_strides = argc == 17 || argc == 29; - if(argc != 32 && argc != 16) + if(argc != 29 && argc != 65 && !default_strides) { print_helper_msg(); exit(1); @@ -50,31 +51,30 @@ int profile_contraction_scale(int argc, char* argv[]) const auto data_type = static_cast(std::stoi(argv[2])); const auto compute_data_type = static_cast(std::stoi(argv[3])); - const auto layout = static_cast(std::stoi(argv[4])); - const bool do_verification = std::stoi(argv[5]); - const ck::index_t init_method = std::stoi(argv[6]); - const bool do_log = std::stoi(argv[7]); - const bool time_kernel = std::stoi(argv[8]); - const float alpha = std::stof(argv[9]); + const ck::index_t NumDimMNK = std::stoi(argv[4]); + const auto layout = static_cast(std::stoi(argv[5])); + const bool do_verification = std::stoi(argv[6]); + const ck::index_t init_method = std::stoi(argv[7]); + const bool do_log = std::stoi(argv[8]); + const bool time_kernel = std::stoi(argv[9]); + const float alpha = std::stof(argv[10]); std::vector M; 
std::vector N; std::vector K; - const ck::index_t dims_arg_num = 10; - collect_index_params(argv, M, dims_arg_num, 2); - collect_index_params(argv, N, dims_arg_num + 2, 2); - collect_index_params(argv, K, dims_arg_num + 4, 2); - - std::vector StridesA; - std::vector StridesB; - std::vector StridesE; - std::vector StridesD; + const ck::index_t dims_arg_num = 11; + collect_index_params(argv, M, dims_arg_num, NumDimMNK); + collect_index_params(argv, N, dims_arg_num + NumDimMNK, NumDimMNK); + collect_index_params(argv, K, dims_arg_num + NumDimMNK * 2, NumDimMNK); + + std::vector StridesA(NumDimMNK * 2); + std::vector StridesB(NumDimMNK * 2); + std::vector StridesE(NumDimMNK * 2); if(!default_strides) { - collect_index_params(argv, StridesA, dims_arg_num + 6, 4); - collect_index_params(argv, StridesB, dims_arg_num + 10, 4); - collect_index_params(argv, StridesE, dims_arg_num + 14, 4); - collect_index_params(argv, StridesD, dims_arg_num + 18, 4); + collect_index_params(argv, StridesA, dims_arg_num + NumDimMNK * 3, NumDimMNK * 2); + collect_index_params(argv, StridesB, dims_arg_num + NumDimMNK * 5, NumDimMNK * 2); + collect_index_params(argv, StridesE, dims_arg_num + NumDimMNK * 7, NumDimMNK * 2); } using F16 = ck::half_t; @@ -93,32 +93,71 @@ int profile_contraction_scale(int argc, char* argv[]) if(default_strides) { - assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); - assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]}); - assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]}); - assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]}); + auto merge_dims = [](const std::vector& dims01, + const std::vector& dims23) { + std::vector dims_szt(dims01.begin(), dims01.end()); + dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end()); + return dims_szt; + }; + + assign_default_strides(a_layout, StridesA, merge_dims(M, K)); + assign_default_strides(b_layout, StridesB, merge_dims(N, K)); + assign_default_strides(cde_layout, StridesE, merge_dims(M, N)); } - bool pass = ck::profiler::profile_contraction_impl, - Scale>(do_verification, - init_method, - do_log, - time_kernel, - Scale{alpha}, - M, - N, - K, - StridesA, - StridesB, - StridesE, - StridesD); - - return pass; + if(NumDimMNK == 2) + { + bool pass = ck::profiler::profile_contraction_impl<2, + ALayout, + BLayout, + CDELayout, + DataType, + ComputeDataType, + ck::Tuple<>, + Scale>(do_verification, + init_method, + do_log, + time_kernel, + Scale{alpha}, + M, + N, + K, + StridesA, + StridesB, + StridesE, + StridesE); + + return pass; + } + else if(NumDimMNK == 6) + { + bool pass = ck::profiler::profile_contraction_impl<6, + ALayout, + BLayout, + CDELayout, + DataType, + ComputeDataType, + ck::Tuple<>, + Scale>(do_verification, + init_method, + do_log, + time_kernel, + Scale{alpha}, + M, + N, + K, + StridesA, + StridesB, + StridesE, + StridesE); + + return pass; + } + else + { + throw std::runtime_error("Not supported NumDimMNK"); + return false; + } }; auto run_profile_for_datatype = [&](auto type, auto compute_type) { diff --git a/test/contraction/test_contraction_interface_xdl.cpp b/test/contraction/test_contraction_interface_xdl.cpp index d6b290e20..58232d209 100644 --- a/test/contraction/test_contraction_interface_xdl.cpp +++ b/test/contraction/test_contraction_interface_xdl.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
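// The argc checks in profile_contraction_scale.cpp above follow from the
// collect_index_params offsets (a worked count, not new behavior): the extents start
// at arg 11 and scale takes strides only for A, B and E, so
//
//   num_dim = 2: 6 extents (args 11-16) + 3 * 4 strides (args 17-28)
//                -> argc == 17 with default strides, argc == 29 with explicit ones.
//   num_dim = 6: 18 extents (args 11-28) + 3 * 12 strides (args 29-64)
//                -> argc == 29 with default strides, argc == 65 with explicit ones.
//
// which is what `default_strides = argc == 17 || argc == 29` and the
// `argc != 29 && argc != 65` guard encode.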
#include #include @@ -125,18 +125,6 @@ class ContractionDeviceOpWrapper } }; -TEST(TestContractionInterface, IncorrectNumDims) -{ - std::vector> Dims = {{4, 4}, {4, 4, 4, 4}, {4, 4, 4, 4, 4, 4}}; - std::vector> Strides = {{1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}}; - ContractionDeviceOpWrapper wrapper_1d; - ContractionDeviceOpWrapper wrapper_2d; - ContractionDeviceOpWrapper wrapper_3d; - EXPECT_FALSE(wrapper_1d.IsSupportedInstance(Dims[0], Strides[0])); - EXPECT_TRUE(wrapper_2d.IsSupportedInstance(Dims[1], Strides[1])); - EXPECT_FALSE(wrapper_3d.IsSupportedInstance(Dims[2], Strides[2])); -} - TEST(TestContractionInterface, IncorrectDataTypes) { std::vector Dims = {4, 4, 4, 4}; diff --git a/test/contraction/test_contraction_xdl.cpp b/test/contraction/test_contraction_xdl.cpp index 958d5be38..c84375b1d 100644 --- a/test/contraction/test_contraction_xdl.cpp +++ b/test/contraction/test_contraction_xdl.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -23,8 +23,11 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Scale = ck::tensor_operation::element_wise::Scale; +template struct Dimensions { + constexpr static ck::index_t NumDimMNK = NDims; + std::vector M; std::vector N; std::vector K; @@ -42,53 +45,58 @@ class TestContraction : public ::testing::Test using ComputeDataType = std::tuple_element_t<5, Tuple>; using CDElementOp = std::tuple_element_t<6, Tuple>; - std::vector dimension_list = {{{32, 32}, {32, 32}, {32, 32}}, - {{16, 16}, {32, 32}, {16, 16}}}; - std::vector init_methods = {1, 2}; std::unique_ptr p_cd_element_op; - void Run() + template + void Run(Dimensions dimension_params) { - for(auto& dimension_params : dimension_list) + constexpr ck::index_t NumDimMNK = ck::remove_cvref_t::NumDimMNK; + + std::vector StridesA(2 * NumDim); + std::vector StridesB(2 * NumDim); + std::vector StridesC(2 * NumDim); + std::vector StridesD(2 * NumDim); + + const auto& M = dimension_params.M; + const auto& N = dimension_params.N; + const auto& K = dimension_params.K; + + auto merge_dims = [](const std::vector& dims01, + const std::vector& dims23) { + std::vector dims_szt(dims01.begin(), dims01.end()); + dims_szt.insert(dims_szt.end(), dims23.begin(), dims23.end()); + return dims_szt; + }; + + assign_default_strides(ALayout{}, StridesA, merge_dims(M, K)); + assign_default_strides(BLayout{}, StridesB, merge_dims(N, K)); + assign_default_strides(CDLayout{}, StridesC, merge_dims(M, N)); + assign_default_strides(CDLayout{}, StridesD, merge_dims(M, N)); + + for(const ck::index_t init_method : init_methods) { - std::vector StridesA; - std::vector StridesB; - std::vector StridesC; - std::vector StridesD; - - const auto& M = dimension_params.M; - const auto& N = dimension_params.N; - const auto& K = dimension_params.K; - - assign_default_strides(ALayout{}, StridesA, {M[0], M[1], K[0], K[1]}); - assign_default_strides(BLayout{}, StridesB, {N[0], N[1], K[0], K[1]}); - assign_default_strides(CDLayout{}, StridesC, {M[0], M[1], N[0], N[1]}); - assign_default_strides(CDLayout{}, StridesD, {M[0], M[1], N[0], N[1]}); - - for(const ck::index_t init_method : init_methods) - { - bool pass = - ck::profiler::profile_contraction_impl(true /*do_verification*/, - init_method, - false /*do_logs*/, - false /*time_kernel*/, - *p_cd_element_op, - dimension_params.M, - dimension_params.N, - 
dimension_params.K, - StridesA, - StridesB, - StridesC, - StridesD); - EXPECT_TRUE(pass); - } + bool pass = + ck::profiler::profile_contraction_impl(true /*do_verification*/, + init_method, + false /*do_logs*/, + false /*time_kernel*/, + *p_cd_element_op, + dimension_params.M, + dimension_params.N, + dimension_params.K, + StridesA, + StridesB, + StridesC, + StridesD); + EXPECT_TRUE(pass); } } }; @@ -122,17 +130,31 @@ TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes); TYPED_TEST(TestContractionBilinear, bilinear) { this->p_cd_element_op = std::make_unique(1.f, 1.f); - this->Run(); + this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); + this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); + this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); + this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); + this->p_cd_element_op = std::make_unique(-0.5f, 0.5f); - this->Run(); + this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); + this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); + this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); + this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); } TYPED_TEST(TestContractionScale, scale) { this->p_cd_element_op = std::make_unique(1.f); - this->Run(); + this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); + this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); + this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); + this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); + this->p_cd_element_op = std::make_unique(0.5f); - this->Run(); + this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); + this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); + this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); + this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); } template @@ -165,15 +187,29 @@ TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecis TYPED_TEST(TestContractionBilinearMixedPrecision, bilinear) { this->p_cd_element_op = std::make_unique(1.f, 1.f); - this->Run(); + this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); + this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); + this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); + this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); + this->p_cd_element_op = std::make_unique(-0.5f, 0.5f); - this->Run(); + this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); + this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); + this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); + this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); } TYPED_TEST(TestContractionScaleMixedPrecision, scale) { this->p_cd_element_op = std::make_unique(1.f); - this->Run(); + this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); + this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); + this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); + this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); + this->p_cd_element_op = std::make_unique(0.5f); - this->Run(); + this->template Run<6>({{2, 3, 2, 3, 2, 3}, {2, 3, 2, 3, 2, 3}, {2, 2, 2, 2, 2, 4}}); + this->template Run<6>({{1, 1, 1, 3, 2, 3}, {1, 1, 1, 3, 2, 3}, {1, 1, 1, 2, 2, 4}}); + this->template Run<2>({{16, 8}, {16, 8}, {16, 8}}); + 
this->template Run<2>({{8, 16}, {16, 8}, {8, 16}}); } -- GitLab From 381d44aa60501098baef69bc78add2c24adb4ba4 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Tue, 9 Apr 2024 21:32:02 -0500 Subject: [PATCH 23/63] add yigex (#1230) --- .github/CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 37407cebf..01e3bee0b 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,4 +1,4 @@ -* @zjing14 @junliume @illsilin @carlushuang @aosewski +* @zjing14 @junliume @illsilin @carlushuang @aosewski @yigex # Documentation files docs/* @ROCm/rocm-documentation *.md @ROCm/rocm-documentation -- GitLab From b2735caf465fd33dc900169f0be5588fe07cfa55 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 11 Apr 2024 07:39:44 -0700 Subject: [PATCH 24/63] Bump rocm-docs-core from 0.38.0 to 0.38.1 in /docs/sphinx (#1232) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.38.0 to 0.38.1. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.38.0...v0.38.1) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 2b28fcdd3..a85454243 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.38.0 +rocm-docs-core==0.38.1 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 335d6e5e0..801726ed6 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -111,7 +111,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.38.0 +rocm-docs-core==0.38.1 # via -r requirements.in six==1.16.0 # via -- GitLab From bbefc12a261b0cbd6efb3e036db1d9a0c0fefe4b Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Thu, 11 Apr 2024 10:35:00 -0500 Subject: [PATCH 25/63] Add instances for conv_scale with bf8@fp8->fp8 (#1231) * Add instances * Add example * Add profiler mode * Add client example --- client_example/16_convnd_fwd/CMakeLists.txt | 3 + .../16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp | 50 +++++++++++ example/09_convnd_fwd/CMakeLists.txt | 1 + .../09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp | 83 +++++++++++++++++++ .../device_grouped_conv_fwd_xdl_instance.hpp | 36 ++++++++ .../gpu/grouped_convolution_forward.hpp | 8 ++ .../gpu/grouped_convolution_forward_xdl.inc | 18 ++++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 2 + ..._ndhwgc_gkzyxc_ndhwgk_bf8_fp8_instance.cpp | 54 ++++++++++++ profiler/src/profile_grouped_conv_fwd.cpp | 8 +- 10 files changed, 262 insertions(+), 1 deletion(-) create mode 100644 client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp create mode 100644 example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_fp8_instance.cpp diff --git a/client_example/16_convnd_fwd/CMakeLists.txt 
b/client_example/16_convnd_fwd/CMakeLists.txt index 808693b63..23311b402 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -20,6 +20,9 @@ endif() if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp) target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations) + + add_executable(client_conv3d_fwd_bf8_fp8 conv3d_fwd_bf8_fp8.cpp) + target_link_libraries(client_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations) endif() if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) diff --git a/client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp b/client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp new file mode 100644 index 000000000..b195d87bb --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::f8_t; +using OutDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +using AComputeType = ck::bf8_t; +using BComputeType = ck::f8_t; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 778e81872..8a295d14c 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -7,6 +7,7 @@ add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) add_example_executable(example_convnd_fwd_xdl_fp16_comp_fp8 convnd_fwd_xdl_fp16_comp_fp8.cpp) add_example_executable(example_convnd_fwd_xdl_fp8_bf8 convnd_fwd_xdl_fp8_bf8.cpp) +add_example_executable(example_convnd_fwd_xdl_bf8_fp8 convnd_fwd_xdl_bf8_fp8.cpp) add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp) add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp) add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp new file mode 100644 index 000000000..9eba00993 --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf8_fp8.cpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
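+// Forward convolution example with mixed 8-bit inputs: bf8 activations and fp8 weights
+// producing fp8 output. AComputeType/BComputeType are set to bf8/fp8 respectively and
+// accumulation is performed in float (AccDataType below).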
+ +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; +using OutDataType = ck::f8_t; +using AComputeType = ck::bf8_t; +using BComputeType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + AComputeType, + BComputeType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 
0 : 1; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index 40878e4f0..af79eefa1 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -326,6 +326,42 @@ using device_grouped_conv_fwd_xdl_f8_bf8_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_bf8_f8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|AComputeType|BComputeType| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + 
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8, F8> +#endif + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index e61ec2828..8602a82ff 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -301,6 +301,14 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances(op_ptrs); + } +#endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index 691414ebc..e627d428d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -369,6 +369,24 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_bf8_instances( BF8>>>& instances); #endif +#if(defined(CK_ENABLE_FP8) && defined(CK_ENABLE_BF8)) +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( + std::vector>>& instances); +#endif + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 50a6ec9a4..579bea00d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -44,6 +44,8 @@ endif() if((DTYPES 
MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_bf8_instance.cpp) + list(APPEND GROUPED_CONV3D_FWD + xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_fp8_instance.cpp) endif() add_instance_library(device_grouped_conv3d_fwd_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_fp8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_fp8_instance.cpp new file mode 100644 index 000000000..08c45d7dd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_fp8_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_f8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf8_f8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf8_f8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf8_f8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index a847999b5..577efafb1 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -26,6 +26,7 @@ enum struct ConvDataType F8_F8_F8, // 4 BF8_BF8_F8, // 5 F8_BF8_F8, // 6 + BF8_F8_F8, // 7 }; #define OP_NAME "grouped_conv_fwd" @@ -42,7 +43,8 @@ static void print_helper_msg() << " 3: Input int8, Weight int8, Output int8\n" << " 4: Input fp8, Weight fp8, Output fp8\n" << " 5: Input bf8, Weight bf8, Output fp8\n" - << " 6: Input fp8, Weight bf8, Output fp8)\n" + << " 6: Input fp8, Weight bf8, Output fp8\n" + << " 7: Input bf8, Weight fp8, Output fp8)\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" << "arg4: verification (0: no, 1: yes)\n" @@ -281,6 +283,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, F8{}, BF8{}); } + else if(data_type == ConvDataType::BF8_F8_F8) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, BF8{}, F8{}); + } } std::cout << "this data_type & layout is not implemented" << std::endl; -- GitLab From d7f05fb996a55a37bb005f27d0c09c7595e5aee2 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 11 Apr 2024 16:40:45 -0700 Subject: [PATCH 26/63] [HotFix] pass XDL and WMMA macros to libs that use CK (#1234) --- 
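This hotfix makes the configure-time CK_USE_XDL / CK_USE_WMMA switches visible to code that
builds against CK, via the generated ck/config.h. A minimal sketch of how a consumer might
branch on them is below; the include path and the printed messages are illustrative
assumptions, not part of this patch.

    // check_ck_targets.cpp - hypothetical consumer of the generated CK config header
    #include <cstdio>
    #include "ck/config.h" // assumed path; generated from include/ck/config.h.in at configure time

    int main()
    {
    #ifdef CK_USE_XDL
        std::printf("CK built with XDL (MI-series) kernel support\n");
    #endif
    #ifdef CK_USE_WMMA
        std::printf("CK built with WMMA (Navi-series) kernel support\n");
    #endif
        return 0;
    }
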
include/ck/config.h.in | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/include/ck/config.h.in b/include/ck/config.h.in index 174834475..eb9049b59 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -104,6 +104,20 @@ #cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@ #endif +// +// CK kernels which support XDL (MI series) +// +#ifndef CK_USE_XDL +#cmakedefine CK_USE_XDL @CK_USE_XDL@ +#endif + +// +// CK Kernels which support WMMA (recent Navi series) +// +#ifndef CK_USE_WMMA +#cmakedefine CK_USE_WMMA @CK_USE_WMMA@ +#endif + // clang-format on #endif // CK_CONFIG_H_IN -- GitLab From 7cdf5a96d284a2d4c9fbb4728040b2f7537e80b3 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 12 Apr 2024 10:55:02 -0700 Subject: [PATCH 27/63] Update the config.h after the CK_USE_XDL/WMMA are set. (#1236) * pass XDL and WMMA macros to libs that use CK * update config.h after XDL and WMMA macros get set --- CMakeLists.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3c77f520a..7b9721d05 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -81,9 +81,6 @@ endif() include(getopt) -# CK config file to record supported datatypes, etc. -configure_file(include/ck/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/config.h) - # CK version file to record release version as well as git commit hash find_package(Git REQUIRED) execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMIT_ID OUTPUT_STRIP_TRAILING_WHITESPACE) @@ -159,6 +156,9 @@ else() set(CK_USE_WMMA "ON") endif() +# CK config file to record supported datatypes, etc. +configure_file(include/ck/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/config.h) + find_package(hip) # No assumption that HIP kernels are launched with uniform block size for backward compatibility # SWDEV-413293 and https://reviews.llvm.org/D155213 -- GitLab From f83e9701e921464bc845a37c77cd491d267d19ef Mon Sep 17 00:00:00 2001 From: Haocong WANG Date: Sun, 14 Apr 2024 10:03:18 +0800 Subject: [PATCH 28/63] [GEMM] Gemm universal device operation (#1154) * Optimize GEMM on MI200/300: 1. Add new blockwise gemm pipeline 2. Add irregular splitk intances * clang format + typo fix * Fix a bug * initial commit * Add more instances to irregular splitk * blkgemm pipeline v1~4 prototype * Sanity Checked. Known issue: 1. Poor performance of splitk 2. Register spill on blkgemmpipeline v3 * Sanity and Performance fix: 1. fix a bug related to sanity in grouped b2c mapping 2. fix a bug related to sanity and performance in splitk offset * Sanity and API update: 1. Remove prefetch stage 2. Fix valid check bug 3, Add first gemm_universal instance into ckProfiler * Add NN instances for gemm universal * 1. Add NT instances for gemm_universal 2. Fix a bug about Kpadding in gemm_universal * Fix a bug regarding padding Odd K number * remove kernel print * Fix KPadding bug... * Update safety check * another try to fix kpadding.. * Sanity checked * new instances.. * clang format+typo fix * remove clang format script's change * Add non-hotloop compile option * 1. Add fp16xfp8 example 2. pull packed convert f8 from pr1150 * Some miscs.. opt and fix * Add pipeline description docs * Split universal gemm instance library to cut profiler compiling time * uncomment cmakefile * Fix a bug caused by blockwise_gemm_pipe_v2 * reduce default splitk to 1 * Add 224x256x64 tile size * update, including: 1. Experiment pipeline 5~7 2. 
Optimization for pipeline 4 3. Organized instance library * temp save * temp save * Permuted lds layout, sanity and function checked * clang format * Move OOB check from RunRead to RunWrite, for better software pipeline. TODO: agpr spill when NN layout * clangformat * A/B splitpipe scheduler for v3 * Fix two bugs * bug fix * fix a bug in oob check * Example for mixed fp16_fp8 gemm * Clean experimental code blocks * Add mixed precision gemm into profiler * tempsave * optimize m/n major lds layout * Add RRR GEMM mixed precision instances * Optimize f8 matrix transpose * Add test_gemm_universal * A/B spilt schedule for blkpip v5 * Take ds_read2 into iglp scheduling scheme * format * fixed cmake * Add llvm-option into CI cmake flag --------- Co-authored-by: Jing Zhang --- Jenkinsfile | 14 +- example/01_gemm/CMakeLists.txt | 7 + example/01_gemm/common.hpp | 62 + example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp | 53 + example/01_gemm/gemm_xdl_fp16_v3.cpp | 48 + example/01_gemm/gemm_xdl_fp8_v3.cpp | 48 + example/01_gemm/run_gemm_example_v2.inc | 211 ++ .../multi_index_transform.hpp | 85 + .../multi_index_transform_helper.hpp | 6 + .../block/blockwise_gemm_pipeline_xdlops.hpp | 4 +- .../blockwise_gemm_pipeline_xdlops_base.hpp | 354 ++++ ...lockwise_gemm_pipeline_xdlops_selector.hpp | 167 ++ .../blockwise_gemm_pipeline_xdlops_v1.hpp | 732 +++++++ .../blockwise_gemm_pipeline_xdlops_v2.hpp | 1154 ++++++++++ .../blockwise_gemm_pipeline_xdlops_v3.hpp | 439 ++++ .../blockwise_gemm_pipeline_xdlops_v4.hpp | 597 ++++++ .../blockwise_gemm_pipeline_xdlops_v5.hpp | 667 ++++++ .../gpu/device/device_gemm_v2.hpp | 43 + .../impl/device_gemm_xdl_cshuffle_v3.hpp | 687 ++++++ .../element/unary_element_wise_operation.hpp | 13 + .../gpu/grid/block_to_ctile_map.hpp | 184 +- .../grid/gridwise_gemm_xdl_cshuffle_v3.hpp | 1871 +++++++++++++++++ .../threadwise_tensor_slice_transfer_v3r1.hpp | 95 +- include/ck/utility/blkgemmpipe_scheduler.hpp | 104 + include/ck/utility/data_type.hpp | 7 + include/ck/utility/synchronization.hpp | 10 +- include/ck/utility/transpose_vectors.hpp | 79 + include/ck/utility/type.hpp | 2 + .../cpu/reference_gemm.hpp | 2 +- .../gpu/gemm_universal.hpp | 505 +++++ .../library/utility/host_tensor_generator.hpp | 26 + .../gpu/gemm_universal/CMakeLists.txt | 70 + ...emm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp | 91 + ...f16_f16_mk_kn_mn_comp_default_instance.cpp | 23 + ...16_f16_mk_kn_mn_comp_kpadding_instance.cpp | 23 + ..._f16_mk_kn_mn_comp_mnkpadding_instance.cpp | 23 + ...6_f16_mk_kn_mn_comp_mnpadding_instance.cpp | 23 + ...6_f16_mk_kn_mn_mem_v1_default_instance.cpp | 24 + ..._f16_mk_kn_mn_mem_v1_kpadding_instance.cpp | 24 + ...16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp | 24 + ...6_f16_mk_kn_mn_mem_v2_default_instance.cpp | 24 + ..._f16_mk_kn_mn_mem_v2_kpadding_instance.cpp | 24 + ...16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp | 27 + ...emm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp | 98 + ...f16_f16_mk_nk_mn_comp_default_instance.cpp | 23 + ...16_f16_mk_nk_mn_comp_kpadding_instance.cpp | 23 + ..._f16_mk_nk_mn_comp_mnkpadding_instance.cpp | 23 + ...6_f16_mk_nk_mn_comp_mnpadding_instance.cpp | 26 + ...6_f16_mk_nk_mn_mem_v1_default_instance.cpp | 24 + ..._f16_mk_nk_mn_mem_v1_kpadding_instance.cpp | 24 + ...16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp | 24 + ...6_f16_mk_nk_mn_mem_v2_default_instance.cpp | 24 + ..._f16_mk_nk_mn_mem_v2_kpadding_instance.cpp | 24 + ...16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp | 24 + ...gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp | 80 + ..._f8_f16_mk_kn_mn_comp_default_instance.cpp | 23 + 
...f8_f16_mk_kn_mn_comp_kpadding_instance.cpp | 23 + ..._f16_mk_kn_mn_comp_mnkpadding_instance.cpp | 23 + ...8_f16_mk_kn_mn_comp_mnpadding_instance.cpp | 23 + ...8_f16_mk_kn_mn_mem_v1_default_instance.cpp | 24 + ..._f16_mk_kn_mn_mem_v1_kpadding_instance.cpp | 24 + ...16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp | 24 + ...8_f16_mk_kn_mn_mem_v2_default_instance.cpp | 24 + ..._f16_mk_kn_mn_mem_v2_kpadding_instance.cpp | 24 + ...16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp | 27 + ...gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp | 86 + ..._f8_f16_mk_nk_mn_comp_default_instance.cpp | 23 + ...f8_f16_mk_nk_mn_comp_kpadding_instance.cpp | 23 + ..._f16_mk_nk_mn_comp_mnkpadding_instance.cpp | 23 + ...8_f16_mk_nk_mn_comp_mnpadding_instance.cpp | 26 + ...8_f16_mk_nk_mn_mem_v1_default_instance.cpp | 24 + ..._f16_mk_nk_mn_mem_v1_kpadding_instance.cpp | 24 + ...16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp | 24 + ...8_f16_mk_nk_mn_mem_v2_default_instance.cpp | 24 + ..._f16_mk_nk_mn_mem_v2_kpadding_instance.cpp | 24 + ...16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp | 24 + ...gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp | 79 + ...f16_f16_mk_kn_mn_comp_default_instance.cpp | 23 + ...16_f16_mk_kn_mn_comp_kpadding_instance.cpp | 23 + ..._f16_mk_kn_mn_comp_mnkpadding_instance.cpp | 23 + ...6_f16_mk_kn_mn_comp_mnpadding_instance.cpp | 23 + ...6_f16_mk_kn_mn_mem_v1_default_instance.cpp | 24 + ..._f16_mk_kn_mn_mem_v1_kpadding_instance.cpp | 24 + ...16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp | 24 + ...6_f16_mk_kn_mn_mem_v2_default_instance.cpp | 24 + ..._f16_mk_kn_mn_mem_v2_kpadding_instance.cpp | 24 + ...16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp | 27 + ...gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp | 86 + ...f16_f16_mk_nk_mn_comp_default_instance.cpp | 23 + ...16_f16_mk_nk_mn_comp_kpadding_instance.cpp | 23 + ..._f16_mk_nk_mn_comp_mnkpadding_instance.cpp | 23 + ...6_f16_mk_nk_mn_comp_mnpadding_instance.cpp | 26 + ...6_f16_mk_nk_mn_mem_v1_default_instance.cpp | 24 + ..._f16_mk_nk_mn_mem_v1_kpadding_instance.cpp | 24 + ...16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp | 24 + ...6_f16_mk_nk_mn_mem_v2_default_instance.cpp | 24 + ..._f16_mk_nk_mn_mem_v2_kpadding_instance.cpp | 24 + ...16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp | 24 + .../profiler/profile_gemm_universal_impl.hpp | 299 +++ profiler/src/CMakeLists.txt | 2 + profiler/src/profile_gemm_universal.cpp | 164 ++ script/cmake-ck-dev.sh | 2 +- test/CMakeLists.txt | 1 + test/gemm_universal/CMakeLists.txt | 4 + .../test_gemm_universal_ut_cases.inc | 113 + .../test_gemm_universal_util.hpp | 91 + .../test_gemm_universal_xdl.cpp | 54 + 107 files changed, 10907 insertions(+), 123 deletions(-) create mode 100644 example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp create mode 100644 example/01_gemm/gemm_xdl_fp16_v3.cpp create mode 100644 example/01_gemm/gemm_xdl_fp8_v3.cpp create mode 100644 example/01_gemm/run_gemm_example_v2.inc create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp create 
mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp create mode 100644 include/ck/utility/blkgemmpipe_scheduler.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp create mode 100644 profiler/include/profiler/profile_gemm_universal_impl.hpp create mode 100644 profiler/src/profile_gemm_universal.cpp create mode 100644 test/gemm_universal/CMakeLists.txt create mode 100644 test/gemm_universal/test_gemm_universal_ut_cases.inc create mode 100644 test/gemm_universal/test_gemm_universal_util.hpp create mode 100644 test/gemm_universal/test_gemm_universal_xdl.cpp diff --git a/Jenkinsfile b/Jenkinsfile index 654c7274f..f28d1b939 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -824,7 +824,7 @@ pipeline { -D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ -D CMAKE_BUILD_TYPE=Release \ -D GPU_TARGETS="gfx908;gfx90a" \ - -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j check""" + -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. 
&& make -j check""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -848,12 +848,12 @@ pipeline { setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ -DCMAKE_EXE_LINKER_FLAGS=" -L ${env.WORKSPACE}/script -T hip_fatbin_insert " \ - -DCMAKE_CXX_FLAGS=" -O3 " """ + -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ - -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ + -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') @@ -868,12 +868,12 @@ pipeline { } agent{ label rocmnode("gfx942") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx942" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ - -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ + -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. && make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') @@ -888,12 +888,12 @@ pipeline { } agent{ label rocmnode("gfx908 || gfx90a") } environment{ - setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx908;gfx90a" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ - -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ + -DCMAKE_CXX_FLAGS=" -Xclang -mllvm -Xclang -enable-post-misched=0 -O3 " .. 
&& make -j """ } steps{ Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 39e3f2a2b..0d3e6287d 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -22,6 +22,13 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16) add_example_executable(example_gemm_xdl_fp16_v2 gemm_xdl_fp16_v2.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v2) +add_example_executable(example_gemm_xdl_fp16_v3 gemm_xdl_fp16_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3) +add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3) +add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3) + add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index eb281af7b..ef87d9c2f 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -46,6 +46,19 @@ struct ProblemSizeStreamK final ck::index_t NumSKBlocks = -1; }; +struct ProblemSizeSplitK final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + ck::index_t KBatch = 1; +}; + struct ExecutionConfig final { bool do_verification = true; @@ -158,3 +171,52 @@ bool parse_cmd_args(int argc, return true; } + +template <> +bool parse_cmd_args(int argc, + char* argv[], + ProblemSizeSplitK& problem_size, + ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc >= 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + + if(argc >= 11) + { + problem_size.KBatch = std::stoi(argv[10]); + } + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl + << "arg10: KBatch" << std::endl; + return false; + } + + return true; +} diff --git a/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp b/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp new file mode 100644 index 000000000..2e27fc66f --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_fp8_v3.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
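+// Mixed-precision GEMM example: fp8 A (row-major) x fp16 B (column-major) -> fp16 C,
+// accumulated in float, using the DeviceGemm_Xdl_CShuffleV3 device op with the
+// Intrawave v1 block-GEMM pipeline. Split-K is controlled at run time through the
+// KBatch argument parsed by run_gemm_splitk_example (see common.hpp and
+// run_gemm_example_v2.inc).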
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 64, + 16, 16, + 64, 16, 8, + 16, 16, + 1, 1, + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 16, 1, 4>, 4, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp16_v3.cpp b/example/01_gemm/gemm_xdl_fp16_v3.cpp new file mode 100644 index 000000000..ad370f570 --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp16_v3.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 224, 256, + 64, 8, 2, + 16, 16, + 7, 8, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 2, 0, + 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp8_v3.cpp b/example/01_gemm/gemm_xdl_fp8_v3.cpp new file mode 100644 index 000000000..cce8a20ff --- /dev/null +++ b/example/01_gemm/gemm_xdl_fp8_v3.cpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
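+// Pure fp8 GEMM example: both A (row-major) and B (column-major) are fp8, the result C is
+// written out as fp16, and accumulation is done in float. This instance uses the Intrawave
+// v3 block-GEMM pipeline of DeviceGemm_Xdl_CShuffleV3, with ck::f8_t passed as the
+// compute-type parameter.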
+ +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +using ADataType = ck::f8_t; +using BDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + PassThrough, PassThrough, PassThrough, GemmDefault, + 256, + 128, 256, + 128, 16, 16, + 16, 16, + 4, 8, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 1, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 1, + 1, 2, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3, ck::f8_t>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example_v2.inc" + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/run_gemm_example_v2.inc b/example/01_gemm/run_gemm_example_v2.inc new file mode 100644 index 000000000..ff6a4acf7 --- /dev/null +++ b/example/01_gemm/run_gemm_example_v2.inc @@ -0,0 +1,211 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + // give a chance if stride is zero, return a default packed stride + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + 
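+// Shared driver for the *_v3 GEMM examples above. run_gemm() fills A/B on the host
+// (init_method 0-3 use constant/integer patterns, anything else uses random floats),
+// launches the DeviceGemmV2Instance with the requested KBatch (split-K factor),
+// optionally verifies against ReferenceGemmInstance and reports TFLOPS and GB/s.
+// The examples accept: verification init_method time_kernel [M N K StrideA StrideB StrideC [KBatch]],
+// e.g. (binary name is illustrative) ./example_gemm_xdl_fp16_v3 1 2 1 3840 4096 4096 4096 4096 4096 4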
b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } +#if 0 + printf("B matrix:\n"); + for (int in = 0; in < N; in++) + { + for (int ik = 0; ik < K; ik++) + { + printf("%02x ", *(reinterpret_cast(&b_k_n(ik,in)))); + if(ik%8==7) printf("|"); + } + printf("\n"); + } +#endif + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * + c_m_n_device_result.mDesc.GetElementSpaceSize()); + + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); +#endif + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 1}); +#ifdef BUILD_INT4_EXAMPLE + Tensor c_m_n_device_result_converted(c_m_n_host_result.mDesc); + + c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data()); + + c_m_n_device_result = c_m_n_device_result_converted.CopyAsType(); + + return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); +#else + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#endif + } + + if(config.time_kernel) + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; 
+ std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/include/ck/tensor_description/multi_index_transform.hpp b/include/ck/tensor_description/multi_index_transform.hpp index ae3139ce7..f68473c29 100644 --- a/include/ck/tensor_description/multi_index_transform.hpp +++ b/include/ck/tensor_description/multi_index_transform.hpp @@ -1951,4 +1951,89 @@ struct Modulo printf("}"); } }; + +template +struct Xor +{ + using LowerIndex = MultiIndex<2>; + using UpperIndex = MultiIndex<2>; + + using UpLengths = LowLengths; + + UpLengths up_lengths_; + + __host__ __device__ constexpr Xor() : up_lengths_{} {} + + __host__ __device__ constexpr Xor(const LowLengths& low_lengths) : up_lengths_{low_lengths} {} + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 2; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 2; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 2 && UpIdx::Size() == 2, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}]; + + idx_low(Number<1>{}) = + idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]); + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up, + Number) const + { + static_assert(LowIdxDiff::Size() == 2 && UpIdxDiff::Size() == 2 && LowIdx::Size() == 2 && + UpIdx::Size() == 2, + "wrong! 
inconsistent # of dimension"); + + const auto idx_low_old = idx_low; + + CalculateLowerIndex(idx_low, idx_up); + + idx_diff_low = idx_low - idx_low_old; + } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("Xor{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + printf("}"); + } +}; } // namespace ck diff --git a/include/ck/tensor_description/multi_index_transform_helper.hpp b/include/ck/tensor_description/multi_index_transform_helper.hpp index af0a8a34d..342ed82c2 100644 --- a/include/ck/tensor_description/multi_index_transform_helper.hpp +++ b/include/ck/tensor_description/multi_index_transform_helper.hpp @@ -127,4 +127,10 @@ __host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus, { return Modulo{modulus, up_length}; } + +template +__host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths) +{ + return Xor{low_lengths}; +} } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp index 7b2aaa76b..5d137e67e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp @@ -960,13 +960,13 @@ struct BlockwiseGemmXdlops_pipeline_v4 static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( make_tuple(Number{}, I1, Number{}, Number{}), make_tuple( - Number{}, Number{}, Number{}, I1)); + Number{}, Number{}, Number{}, I1)); // B[N0, N1, N2, KPack] static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( make_tuple(Number{}, I1, Number{}, Number{}), make_tuple( - Number{}, Number{}, Number{}, I1)); + Number{}, Number{}, Number{}, I1)); // C[M, N, NumRegXdlops] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp new file mode 100644 index 000000000..036a40488 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -0,0 +1,354 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +struct BlockwiseGemmXdlops_pipeline_base +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs. 
+ static constexpr index_t WaveSize = 64; + + static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); + static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = + XdlopsGemm{}; + + static constexpr index_t AMmaKStride = KPack; + static constexpr index_t BMmaKStride = KPack; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + static constexpr index_t KRepeat = KPerThread / KPack; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + using HotLoopInstList = + ck::BlockwiseGemmXdlops_pipeline_hotloop_inst; + + static_assert(KPerThread % KPack == 0, + "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple( + m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], 
blk_idx[I2], blk_idx[I3]); + } + + using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + + __host__ __device__ + BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + 
c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + + protected: + // M1, N1 as double buffer index + // Read buffer + Compute buffer + // A[M0, M1, M2, KPack] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // B[N0, N1, N2, KPack] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp new file mode 100644 index 000000000..2558b58c2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
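+//
+// Usage sketch (illustrative only; the concrete template parameter list is the
+// one declared below): BlockGemmPipeline_Selector() is evaluated at compile
+// time and returns an instance of one of the BlockwiseGemmXdlops_pipeline_v1
+// through _v5 specializations chosen by BlkGemmPipelineVer, e.g.
+//
+//   using BlockwiseGemmPipe =
+//       remove_cvref_t<decltype(BlockGemmPipeline_Selector< /* ... */ >())>;
+//
+// so the selected pipeline's Run() is what the calling kernel ends up invoking.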
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp" + +namespace ck { + +enum struct BlockGemmPipelineVersion +{ + v1, // Naive + v2, // Mem + v3, // Comp + v4, // Comp, double lds buffer + v5, // Comp, double global prefetch register buffer +}; + +template +constexpr auto BlockGemmPipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmXdlops_pipeline_v1{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + return BlockwiseGemmXdlops_pipeline_v2{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmXdlops_pipeline_v3{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return BlockwiseGemmXdlops_pipeline_v4{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5) + { + return BlockwiseGemmXdlops_pipeline_v5{}; + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp new file mode 100644 index 000000000..0a7ad545b --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -0,0 +1,732 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
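+//
+// Host-side dispatch note (inferred from the member signatures in this file,
+// not stated elsewhere in this patch): BlockHasHotloop(num_loop) and
+// BlockLoopTailNum(num_loop) are constexpr host helpers a caller can use to
+// pick the HasMainLoop and TailNum template arguments of Run() for a given
+// K-loop trip count.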
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 1 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + // ------------------------------------------------------------------------------------------- + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 
1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +template +struct BlockwiseGemmXdlops_pipeline_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KPerThread; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS; + 
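+    // Worked example for the K split below (hypothetical values): with
+    // KPerThread = 32, NumMacClusters = 2 and KPack = 8, KPerInnerLoop =
+    // max(32 / 2, 8) = 16 and KRepeat = 32 / 16 = 2, i.e. each K-loop iteration
+    // runs two MAC clusters of 16 / 8 = 2 xdlops K-steps each.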
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); + static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + // ------------------------------------------------------------------------------------------- + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + __builtin_amdgcn_sched_barrier(0); + // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, + // but except the first, as we can shorten non-MAC cluster a bit and there's no + // observable negative impact. The desired effect is waves in a workgroup + // executing MAC in sync. This avoids some out-of-sync waves hijacking MAC + // resource from other workgroups and reducing the chance of latency hiding by + // waiting for the rest of the workgroup at the eventual sync point. 
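+                // (With KRepeat == 1 there is only a single cluster, and the
+                // condition below keeps the barrier for that case.)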
+ if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion by + // applying small delays to different wavefronts It is performed + // near the end of MAC cluster to minimize lgkmcnt penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + // block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + 
__builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + } + + protected: + // K->M loopover + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp new file mode 100644 index 000000000..45b1ec341 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -0,0 +1,1154 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Maximum Global Memory throughput pipeline with >=32KB data in fly +// GlobalPrefetchStages: >=2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v2 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v2 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( + 32768 / (4 * warpSize / BlockSize), + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? 
FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch [2, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { + // ------------------------------------------------------------------------------------------- + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + 
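+                            // The KPack scalars gathered above are packed into
+                            // one vector per operand and fed to a single
+                            // xdlops_gemm.Run call; c_offset below selects the
+                            // (m0, n0) accumulator slot in c_thread_buf.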
+ using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // tail + + auto LoopTailFunc = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); + }); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + + if 
constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Two) + { + LoopTailFunc(Number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + LoopTailFunc(Number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + LoopTailFunc(Number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + LoopTailFunc(Number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + LoopTailFunc(Number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + LoopTailFunc(Number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + LoopTailFunc(Number{}); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +template +struct BlockwiseGemmXdlops_pipeline_v2 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KPerThread; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS; + static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); + static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; + + static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( + 32768 / (4 * warpSize / BlockSize), + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? 
FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch [2, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { + // ------------------------------------------------------------------------------------------- + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + __builtin_amdgcn_sched_barrier(0); + // NOTE: Synchronize threads in a workgroup at the start of each MAC + // cluster, but except the first, as we can shorten non-MAC cluster a bit + // and there's no observable negative impact. The desired effect is waves in + // a workgroup executing MAC in sync. 
This avoids some out-of-sync waves + // hijacking MAC resource from other workgroups and reducing the chance of + // latency hiding by waiting for the rest of the workgroup at the eventual + // sync point. + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion + // by applying small delays to different wavefronts It is + // performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + // block_sync_lds(); + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // tail + + auto LoopTailFunc = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template 
AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); + }); + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + 
__builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + else if constexpr(TailNum == TailNumber::Two) + { + LoopTailFunc(Number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + LoopTailFunc(Number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + LoopTailFunc(Number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + LoopTailFunc(Number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + LoopTailFunc(Number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + LoopTailFunc(Number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + LoopTailFunc(Number{}); + } + } + + protected: + // K->M loopover + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp new file mode 100644 index 000000000..9d1301ae2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -0,0 +1,439 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
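+//
+// Scheduling note for HotLoopScheduler() below: the
+// __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) calls interleave
+// instruction classes according to the masks used in this file, matching the
+// inline comments next to each call: 0x008 = MFMA, 0x020 = VMEM read,
+// 0x100 = DS read, 0x200 = DS write.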
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v3 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 8 + ds_read_a_issue_cycle - 1) / ds_read_a_issue_cycle; + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 8 + ds_read_b_issue_cycle - 1) / ds_read_b_issue_cycle; + + // stage 1 + // Separate this part? + constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + sizeof(ComputeDataType) / sizeof(BDataType) + ? 
sizeof(ComputeDataType) / sizeof(ADataType) + : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_mfma_stage1 = + num_mfma_inst - num_mfma_per_ds_read * (num_ds_read_inst_a / ds_read_a_mfma_rate + + num_ds_read_inst_b / ds_read_b_mfma_rate); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_ds_read_inst_a / ds_read_a_mfma_rate, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_ds_read, 0); // MFMA + }); + + static_for<0, num_ds_read_inst_b / ds_read_b_mfma_rate, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_ds_read, 0); // MFMA + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 
1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr(TailNum == TailNumber::Full) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp new file mode 100644 index 000000000..75569150b --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp @@ -0,0 +1,597 
@@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimimal pipeline with highest resource request +// GlobalPrefetchStages: 4 +// LocalPreFillStages: 2 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 2 + +template +struct BlockwiseGemmXdlops_pipeline_v4 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v4 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 4; + static constexpr index_t PrefillStages = 2; + static constexpr index_t GlobalBufferNum = 2; + static constexpr index_t HotloopUnroll = 2; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % HotloopUnroll == 1) + { + return TailNumber::Odd; + } + else + { + return TailNumber::Even; + } + } + + template + __device__ static constexpr void HotLoopScheduler(ScheduleGroup schedule_group) + { + // TODO: Take data type into consideration as pipe ver 3 + // A-B splited schedule + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? 
HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_a = + (HotLoopInstList::A_LDS_Write_Inst_Num + num_issue_a - 1) / num_issue_a; + constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a; + + constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_b = + (HotLoopInstList::B_LDS_Write_Inst_Num + num_issue_b - 1) / num_issue_b; + constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b; + + constexpr auto num_mfma_per_issue = + HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b); + + static_for<0, num_issue_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) { + ignore = idsread; + __builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA + }); + + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA + }); + + __builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_dsread_per_issue_a - + num_dswrite_per_issue_a, + schedule_group); // MFMA + }); + + static_for<0, num_issue_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) { + ignore = idsread; + __builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA + }); + + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA + }); + + __builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_dsread_per_issue_a - + num_dswrite_per_issue_b, + schedule_group); // MFMA + }); + __builtin_amdgcn_sched_barrier(0); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + + 
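+        // Advance the source slice windows before the next global prefetch stage.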
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0), I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0), I0); + + // Local prefill 2 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1), I1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1), I1); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(I0)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); + }); + }); + }); + + // Global prefetch 3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Global prefetch 4 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + // This hot loop has two legacy loopover, to implement the double local buffer strategy + do + { + auto LoopFunc = [&](auto lds_read_buf, + auto lds_read_reg_buf, + auto lds_write_buf, + auto vmem_buf, + auto mfma_reg_buf, + auto schedule_group) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + }); + + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template 
AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(schedule_group); + }; + + LoopFunc(I1, I1, I0, I0, I0, I0); + LoopFunc(I0, I0, I1, I1, I1, I0); + + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + + auto ReadWriteCompFunc = [&](auto lds_read_buf, + auto lds_read_reg_buf, + auto lds_write_buf, + auto vmem_buf, + auto mfma_reg_buf, + auto schedule_group) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(schedule_group); + }; + + auto ReadCompFunc = [&](auto lds_read_buf, + auto lds_read_reg_buf, + auto mfma_reg_buf, + auto schedule_group) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(schedule_group); + }; + + auto CompFunc = [&](auto mfma_reg_buf) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + 
static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + // tail + if constexpr(TailNum == TailNumber::Odd) + { + ReadWriteCompFunc(I1, I1, I0, I0, I0, I1); + ReadCompFunc(I0, I0, I1, I1); + CompFunc(I0); + } + else if constexpr(TailNum == TailNumber::Even) + { + ReadWriteCompFunc(I1, I1, I0, I0, I0, I1); + ReadWriteCompFunc(I0, I0, I1, I1, I1, I1); + ReadCompFunc(I1, I1, I0, I1); + CompFunc(I1); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp new file mode 100644 index 000000000..9711f8e41 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp @@ -0,0 +1,667 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 3 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 2 + +template +struct BlockwiseGemmXdlops_pipeline_v5 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v5 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 3; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + static constexpr index_t HotloopUnroll = 2; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % HotloopUnroll == 1) + { + return TailNumber::Odd; + } + else + { + return TailNumber::Even; + } + } + + __device__ static constexpr auto HotLoopScheduler() + { + 
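+        // Issue-order sketch of the schedule built below (one unrolled iteration):
+        //   stage 1: ds_read for the first (KRepeat - 1) k-steps, paced against MFMA;
+        //   stage 2: ds_write and buffer_load issues interleaved with the bulk of the MFMAs;
+        //   stage 3: the remaining (one k-step's worth of) ds_read, again paced against MFMA.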
// TODO: Take data type into consideration as pipe ver 3 + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 8 + ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 8 + ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1); + constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1); + constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat; + constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat; + + constexpr auto num_dsread_stage1_a_mfma = + (num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_stage1_b_mfma = + (num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + constexpr auto num_dsread_stage3_a_mfma = + (num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_stage3_b_mfma = + (num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + constexpr auto num_mfma_stage2 = num_mfma_inst - num_ds_read_inst_a / ds_read_a_mfma_rate - + num_ds_read_inst_b / ds_read_b_mfma_rate; + constexpr auto num_mfma_per_issue = + num_mfma_stage2 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + // stage 1 + static_for<0, num_dsread_stage1_a_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, num_dsread_stage1_b_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate, 
+                    0); // DS read
+            }
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+        });
+
+        // stage 2
+        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
+            ignore = i;
+            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
+                ignore = idswrite;
+                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
+                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+            });
+            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
+            __builtin_amdgcn_sched_group_barrier(
+                0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
+        });
+        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
+            ignore = i;
+            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
+                ignore = idswrite;
+                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
+                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+            });
+            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
+            __builtin_amdgcn_sched_group_barrier(
+                0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
+        });
+
+        // stage 3
+        static_for<0, num_dsread_stage3_a_mfma, 1>{}([&](auto i) {
+            ignore = i;
+            if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >=
+                         ds_read_a_mfma_rate)
+            {
+                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
+            }
+            else
+            {
+                __builtin_amdgcn_sched_group_barrier(
+                    0x100,
+                    num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate,
+                    0); // DS read
+            }
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+        });
+        static_for<0, num_dsread_stage3_b_mfma, 1>{}([&](auto i) {
+            ignore = i;
+            if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >=
+                         ds_read_b_mfma_rate)
+            {
+                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
+            }
+            else
+            {
+                __builtin_amdgcn_sched_group_barrier(
+                    0x100,
+                    num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate,
+                    0); // DS read
+            }
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
+        });
+
+        // IGLP COMPILER BUG:
+        // Commenting out the following scheduler barrier causes sanity failures.
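+        // (A sched_barrier with mask 0 keeps the compiler from moving any instruction
+        // across this point, closing off the sched_group_barrier region built above.)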
+ __builtin_amdgcn_sched_barrier(0); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Global prefetch 3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto vmem_buf) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KRepeat, 1>{}([&](auto k0) { + if constexpr(k0 == (KRepeat - 1)) + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + } + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + 
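+                        // The MFMAs for this (m0, k0) slice have been issued; while they
+                        // execute, refill the A register fragment for the next k-step
+                        // ((k0 + 1) % KRepeat) from LDS.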
a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + }; + + LoopFunc(I0); + LoopFunc(I1); + + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + // tail + auto ReadWriteCompFunc = [&](auto vmem_buf) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KRepeat, 1>{}([&](auto k0) { + if constexpr(k0 == (KRepeat - 1)) + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf); + + block_sync_lds(); + } + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + }; + auto ReadCompFunc = [&]() { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KRepeat - 1, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + }); + static_for<0, KPack, 1>{}([&](auto 
ik) { + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + HotLoopScheduler(); + }; + + if constexpr(TailNum == TailNumber::Odd) + { + ReadWriteCompFunc(I0); + ReadWriteCompFunc(I1); + ReadCompFunc(); + } + else if constexpr(TailNum == TailNumber::Even) + { + ReadWriteCompFunc(I0); + ReadCompFunc(); + } + } + + protected: + // A[MRepeat, I1, I1, KPack] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, Number{})); + + // B[NRepeat, N1, N2, KPack] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp new file mode 100644 index 000000000..c06b26aa2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmV2 : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t KSplit, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp new file mode 100644 index 000000000..9d3e97c3e --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -0,0 +1,687 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
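+//
+// Host-side usage sketch through the DeviceGemmV2 interface (illustrative only;
+// "gemm" stands for an already-instantiated DeviceGemm_Xdl_CShuffleV3 and the
+// element-wise operations are assumed to be PassThrough):
+//
+//   auto arg     = gemm.MakeArgumentPointer(p_a, p_b, p_c, M, N, K,
+//                                           StrideA, StrideB, StrideC, KBatch,
+//                                           PassThrough{}, PassThrough{}, PassThrough{});
+//   auto invoker = gemm.MakeInvokerPointer();
+//   if(gemm.IsSupportedArgument(arg.get()))
+//   {
+//       const float ave_time = invoker->Run(arg.get(), StreamConfig{});
+//   }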
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { 
+ const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return 
IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"<(x); y = type_convert(t); } + constexpr const static bool is_pack2_invocable = true; }; struct PassThrough @@ -131,12 +132,24 @@ struct PassThrough y = type_convert(x); } + template <> + __host__ __device__ void operator()(int32_t& y, const int8_t& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(int8_t& y, const float& x) const { y = type_convert(x); } + template <> + __host__ __device__ void operator()(float& y, const int8_t& x) const + { + y = type_convert(x); + } + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> __host__ __device__ void operator()(int4_t& y, const int4_t& x) const diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index a89e14cbd..148aba5aa 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -259,46 +259,20 @@ struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt -struct BlockToCTileMap_Grouped_M00_N0_M01Adapt; +// Grouped Rows of column-vectors WGP mapping +// Optimized for MI300-like multipe-die chip template -struct BlockToCTileMap_Grouped_M00_N0_M01Adapt +struct BlockToCTileMap_Grouped_M00_N0_M01Adapt { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default; - - __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt( - const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default; - __host__ __device__ - BlockToCTileMap_Grouped_M00_N0_M01Adapt(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default; - __host__ 
__device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt& - operator=(const BlockToCTileMap_Grouped_M00_N0_M01Adapt&) = default; - __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt& - operator=(BlockToCTileMap_Grouped_M00_N0_M01Adapt&&) = default; - __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) : M_(M), N_(N), M01_(M01) { -#if 0 - if(get_thread_global_1d_id()==0){ - printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_); - } -#endif - } - - template - __host__ __device__ - BlockToCTileMap_Grouped_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8) - : BlockToCTileMap_Grouped_M00_N0_M01Adapt( - c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01) - { } __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N) @@ -309,12 +283,6 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt - __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) - { - return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); - } - template __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { @@ -329,67 +297,82 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt| - * - * NPerBlock NPerBlock NPerBlock NPerBlock - * N_0 N_1 N_2 N_3 - * - |-----------|-----------|-----------|-----|-----|- - * ^ | - - 0 |/----> 2 | | | | - * | | | / | | | | | M_0 MPerBlock - * | M | /| | | | | | - * |-0---|---/-|-----|-----|-----------|-----|-----|- - * | 1 | / | | | blockid | | | - * idxM0 | | | / | V | 5 | | | M_1 MPerBlock - * | - V 1 | - 3 | | | | - * |-----------|-----------|-----------|-----|-----|- - * mtx M | | | | | | - * | | | | | | M_2 MPerBlock - * | | | | | | - * |-----------|-----------|-----------|-----|-----|- - * | | | | | | - * | | | | | | M_3 MPerBlock - * | | | | | | - * |-----------|-----------|-----------|-----|-----|- - * V | | | | | | - * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock - * | | | | | | - * |-----------|-----------|-----------|-----|-----|- - * Example: - * assume: - * M0 = 5 - * N0 = 4 - * block_1d_id = 5 - * M01 = 2 - * - * idx_N0 = 1 - * idx_M0 = 1 - * M01_adapt = 2 - * idx_M00 = 0 - * idx_M01 = 1 - * idx_N0_M01_local = 5 - * output {1, 2} - */ - - return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, - idx_N0_M01_local / M01_adapt); + if(M0 == 1) + { + return make_tuple(0, block_1d_id); + } + else if(N0 == 1) + { + return make_tuple(block_1d_id, 0); + } + // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + else + { + const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum); + const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0); + auto group_id_x = block_1d_id % GroupNum; + auto group_id_y = block_1d_id / GroupNum; + auto remap_block_1d_id = + group_id_x <= big_group_num + ? group_id_x * group_size + group_id_y + : group_id_x * group_size + big_group_num - group_id_x + group_id_y; + + index_t idx_N0 = remap_block_1d_id % N0; + index_t idx_M0 = remap_block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? 
M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + /** + * idxN0 + * + * |< mtx N >| + * + * NPerBlock NPerBlock NPerBlock NPerBlock + * N_0 N_1 N_2 N_3 + * - |-----------|-----------|-----------|-----|-----|- + * ^ | - - 0 |/----> 2 | | | | + * | | | / | | | | | M_0 MPerBlock + * | M | /| | | | | | + * |-0---|---/-|-----|-----|-----------|-----|-----|- + * | 1 | / | | | blockid | | | + * idxM0 | | | / | V | 5 | | | M_1 MPerBlock + * | - V 1 | - 3 | | | | + * |-----------|-----------|-----------|-----|-----|- + * mtx M | | | | | | + * | | | | | | M_2 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * | | | | | | + * | | | | | | M_3 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * V | | | | | | + * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * Example: + * assume: + * M0 = 5 + * N0 = 4 + * block_1d_id = 5 + * M01 = 2 + * + * idx_N0 = 1 + * idx_M0 = 1 + * M01_adapt = 2 + * idx_M00 = 0 + * idx_M01 = 1 + * idx_N0_M01_local = 5 + * output {1, 2} + */ + + return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } } template @@ -405,15 +388,6 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt -struct BlockToCTileMap_Grouped_M00_N0_M01Adapt - : BlockToCTileMap_Grouped_M00_N0_M01Adapt -{ - using BlockToCTileMap_Grouped_M00_N0_M01Adapt:: - BlockToCTileMap_Grouped_M00_N0_M01Adapt; -}; - // columns of row-vectors // This C-tile map dynamically adjusts N01 when C-tile index is out of range template diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp new file mode 100644 index 000000000..3e0debfa1 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -0,0 +1,1871 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid, + p_shared, + karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto 
CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + 
make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + 
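+    // Worked example (values are illustrative only, not a fixed configuration): with
+    // MPerBlock = 128, MXdlPerWave = 2 and MPerXdl = 32, MWaves = 128 / (2 * 32) = 2,
+    // i.e. the 128-row block tile is covered by 2 waves, each owning 2 xdlops tiles of
+    // 32 rows. MakeBMmaTileDescriptor_N0_N1_N2_K below applies the same decomposition to N.
+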
template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t 
StrideB_, + index_t StrideC_, + index_t k_batch_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.M; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.N; + } + else if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead; + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_transform(make_tuple(Number{}, + Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 
1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_transform(make_tuple(Number{}, + Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + c_block_size * 
sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + 
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + 
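+        // A and B share the single LDS allocation: the A tile occupies the first
+        // a_block_space_size_aligned elements and the B tile starts right behind it.
+        // The byte offset a_block_space_size_aligned * sizeof(ADataType) is divided by
+        // sizeof(BDataType) because the B view below is indexed in BDataType elements.
+        // Illustrative numbers only: with 8192 fp16 A elements the B tile starts at byte
+        // 16384, i.e. element 8192 of an fp16 B view or element 16384 of an int8 B view.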
auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c 
matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + 
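+                // Note: the shuffle tile lives in the same p_shared that held the A/B tiles,
+                // which is why GetSharedMemoryNumberOfByte() returns the max of the two usages;
+                // the barrier above also orders this write against the LDS->global read issued
+                // for the previous access_id.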
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + 
decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! 
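+            // The epilogue below mirrors the single-LDS path: accumulators are staged through
+            // p_shared_0, which is reused for the shuffle tile once the main loop no longer
+            // needs it, in chunks of CShuffleMXdlPerWavePerShuffle x
+            // CShuffleNXdlPerWavePerShuffle xdlops tiles before being written to global memory.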
+ constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + 
n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 0b2300fbe..699a34418 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -202,15 +202,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1 constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); + // maintain a container record is_src_valid, waiting for RunWrite use. 
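+            // The source element itself is loaded unconditionally below (valid flag forced
+            // to true); instead of masking here, the per-vector validity is stored in
+            // src_oob_thread_scratch_tuple_ and applied later, when
+            // TransferDataFromSrcThreadScratchToDstThreadScratch zero-fills the vectors
+            // whose coordinates were out of bounds, before RunWrite writes the scratch out.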
const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + src_oob_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType(src_data_idx_seq, is_src_valid); using src_vector_type = vector_type_maker_t; using src_vector_t = typename src_vector_type::type; - // copy data from src_buf into src_vector_container - auto src_vector_container = src_vector_type{ - src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + auto src_vector_container = + src_vector_type{src_buf.template Get(src_coord_.GetOffset(), true)}; using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; @@ -305,12 +307,78 @@ struct ThreadwiseTensorSliceTransfer_v3r1 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; }); #else + + // OOB Check + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using vector_t = typename vector_type_maker::type::type; + + auto op_r = src_thread_scratch_tuple_(thread_scratch_id) + .template GetAsType(src_data_idx_seq); + + const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id) + .template GetAsType(src_data_idx_seq); + + auto op_r_v = is_src_valid ? 
op_r : vector_t(0); + + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType(src_data_idx_seq, op_r_v); + }); + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ // TODO make this logic more generic for more sub-dword datatype if constexpr(SrcVectorDim != DstVectorDim && ((is_same>::value && SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) || + (is_same>::value && SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) { // each transpose does @@ -386,6 +454,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 Number thread_scratch_id = Number{}) { // if there is transpose, it's done here + // if there is oob check, it's done here // TODO move this elsewhere TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); @@ -738,6 +807,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1 return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); } + __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + return make_naive_tensor_descriptor_packed(src_access_lengths); + } + __device__ static constexpr auto GetDstThreadScratchDescriptor() { // 1st stage of transforms @@ -789,6 +868,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 private: static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto src_oob_thread_scratch_desc_ = + decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; using SrcThreadScratch = @@ -798,6 +879,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 decltype(src_thread_scratch_desc_), true>; + using SrcOOBThreadScratch = + StaticTensorTupleOfVectorBuffer; + using DstThreadScratch = StaticTensorTupleOfVectorBuffer; StaticallyIndexedArray src_thread_scratch_tuple_; + StaticallyIndexedArray src_oob_thread_scratch_tuple_; DstThreadScratch dst_thread_scratch_; diff --git a/include/ck/utility/blkgemmpipe_scheduler.hpp b/include/ck/utility/blkgemmpipe_scheduler.hpp new file mode 100644 index 000000000..902195e2f --- /dev/null +++ b/include/ck/utility/blkgemmpipe_scheduler.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
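+//
+// Scheduling helpers for the blockwise GEMM hot loop: BlockGemmPipelineScheduler
+// selects intra-wave vs. inter-wave instruction scheduling, TailNumber names the
+// possible tail iterations left over by the prefetch/unroll pipelines, and
+// BlockwiseGemmXdlops_pipeline_hotloop_inst derives per-hot-loop instruction counts
+// (buffer loads, LDS reads/writes, MFMAs) from the tile configuration.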
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +enum struct BlockGemmPipelineScheduler +{ + Intrawave, + Interwave, +}; + +enum struct TailNumber +{ + // Single / Double buffer pipeline + Odd, + Even, + + // Long prefetch pipeline, up to 8 + One, + Two, + Three, + Four, + Five, + Six, + Seven, + + // Unroll stages > Prefetch stages, number of loop is multiple of unroll stages + Empty, + // Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add + // prefetchstages + Full, +}; +template +struct BlockwiseGemmXdlops_pipeline_hotloop_inst +{ + static constexpr index_t WaveSize = 64; + static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL); + + static constexpr index_t A_LDS_Read_Width = ALDSReadWidth; + static constexpr index_t B_LDS_Read_Width = BLDSReadWidth; + + static constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth); + static constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth); + + static constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth); + static constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth); + + static constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth); + static constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth); + + static constexpr index_t C_MFMA_Inst_Num = + MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + static constexpr auto Print() + { + printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n", + BlockSize, + WaveSize, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + KPerXDL); + + printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " + "%d, %d\n C MFMA inst: %d\n", + A_Buffer_Load_Inst_Num, + B_Buffer_Load_Inst_Num, + A_LDS_Write_Inst_Num, + B_LDS_Write_Inst_Num, + A_LDS_Read_Inst_Num, + B_LDS_Read_Inst_Num, + C_MFMA_Inst_Num); + } +}; + +} // namespace ck diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 4d6791b5a..93a1edefb 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -163,6 +163,13 @@ struct scalar_type static constexpr index_t vector_size = 1; }; +template <> +struct scalar_type +{ + using type = bool; + static constexpr index_t vector_size = 1; +}; + template struct vector_type { diff --git a/include/ck/utility/synchronization.hpp b/include/ck/utility/synchronization.hpp index e653d46d3..4fe5e3950 100644 --- a/include/ck/utility/synchronization.hpp +++ b/include/ck/utility/synchronization.hpp @@ -10,10 +10,12 @@ namespace ck { __device__ void block_sync_lds() { #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM - asm volatile("\ - s_waitcnt lgkmcnt(0) \n \ - s_barrier \ - " ::); + // asm volatile("\ + // s_waitcnt lgkmcnt(0) \n \ + // s_barrier \ + // " ::); + __builtin_amdgcn_s_waitcnt(0xc07f); + __builtin_amdgcn_s_barrier(); #else __syncthreads(); #endif diff --git a/include/ck/utility/transpose_vectors.hpp b/include/ck/utility/transpose_vectors.hpp index 6faf5c133..e73ec03de 100644 --- a/include/ck/utility/transpose_vectors.hpp +++ b/include/ck/utility/transpose_vectors.hpp @@ 
-162,4 +162,83 @@ struct transpose_vectors } }; +// transpose f8 4x4 +__device__ void transpose_f8_4x4(const f8x4_t& x0, + const f8x4_t& x1, + const f8x4_t& x2, + const f8x4_t& x3, + f8x4_t& y0, + f8x4_t& y1, + f8x4_t& y2, + f8x4_t& y3) +{ + int32_t t0, t1; + int32_t z0, z1, z2, z3; + constexpr int32_t m0 = 0x05010400; + constexpr int32_t m1 = 0x05040100; + constexpr int32_t m2 = 0x07060302; + constexpr int32_t m3 = 0x07030602; + + // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488 + // -- -- -- -- -- -- -- -- - - - - + // index 7 6 5 4 3 2 1 0 33 77 44 88 + // index is reversed because of little endianness (least significant bits first) + t0 = __builtin_amdgcn_perm(bit_cast(x1), bit_cast(x0), m0); + t1 = __builtin_amdgcn_perm(bit_cast(x3), bit_cast(x2), m0); + z0 = __builtin_amdgcn_perm(bit_cast(t1), bit_cast(t0), m1); + z1 = __builtin_amdgcn_perm(bit_cast(t1), bit_cast(t0), m2); + t0 = __builtin_amdgcn_perm(bit_cast(x1), bit_cast(x0), m3); + t1 = __builtin_amdgcn_perm(bit_cast(x3), bit_cast(x2), m3); + z2 = __builtin_amdgcn_perm(bit_cast(t1), bit_cast(t0), m1); + z3 = __builtin_amdgcn_perm(bit_cast(t1), bit_cast(t0), m2); + + y0 = bit_cast(z0); + y1 = bit_cast(z1); + y2 = bit_cast(z2); + y3 = bit_cast(z3); +} + +template +struct transpose_vectors +{ + // we got [NY * NX] amount of S data to be transposed + static constexpr index_t s_per_x = NY; + static constexpr index_t s_per_y = NX; + + using S = f8_t; + using VX = vector_type; + using VY = vector_type; + + __device__ void operator()(const StaticallyIndexedArray& vx_tuple, + StaticallyIndexedArray& vy_tuple) + { + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!"); + + // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple + static_for<0, NY, 4>{}([&](auto iy) { + static_for<0, NX, 4>{}([&](auto ix) { + // reference to 4 f8 data from vx_tuple + const auto& x_s4_0 = vx_tuple[ix].template AsType()[iy / I4]; + const auto& x_s4_1 = vx_tuple[ix + I1].template AsType()[iy / I4]; + const auto& x_s4_2 = vx_tuple[ix + I2].template AsType()[iy / I4]; + const auto& x_s4_3 = vx_tuple[ix + I3].template AsType()[iy / I4]; + + // reference to 4 f8 data from vy_tuple + auto& y_s4_0 = vy_tuple(iy).template AsType()(ix / I4); + auto& y_s4_1 = vy_tuple(iy + I1).template AsType()(ix / I4); + auto& y_s4_2 = vy_tuple(iy + I2).template AsType()(ix / I4); + auto& y_s4_3 = vy_tuple(iy + I3).template AsType()(ix / I4); + + // transpose + transpose_f8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3); + }); + }); + } +}; + } // namespace ck diff --git a/include/ck/utility/type.hpp b/include/ck/utility/type.hpp index 9609afba4..cc011d722 100644 --- a/include/ck/utility/type.hpp +++ b/include/ck/utility/type.hpp @@ -43,6 +43,8 @@ __host__ __device__ constexpr Y bit_cast(const X& x) #if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST Y y; + // auto t = reinterpret_cast(&x); + // y = *t; __builtin_memcpy(&y, &x, sizeof(X)); return y; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 4d52563f4..e1edc4fae 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -21,7 +21,7 @@ template struct ReferenceGemm : 
public device::BaseOperator { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp new file mode 100644 index 000000000..4047f0096 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp @@ -0,0 +1,505 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +#ifdef CK_ENABLE_FP16 +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); +#endif +#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + 
instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances); + +void 
add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances); + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances); +#endif + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmV2> +{ + using DeviceOp = DeviceGemmV2; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instances(op_ptrs); + add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs); + add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs); + add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs); + add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + op_ptrs); + add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + op_ptrs); + } + } +#endif +#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(op_ptrs); + add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs); + 
add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs); + add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs); + add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs); + add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances(op_ptrs); + add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs); + add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances(op_ptrs); + add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances(op_ptrs); + add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances(op_ptrs); + add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs); + add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs); + add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(op_ptrs); + add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(op_ptrs); + add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs); + add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(op_ptrs); + add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs); + add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances(op_ptrs); + add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances(op_ptrs); + add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + op_ptrs); + + add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances(op_ptrs); + 
add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + op_ptrs); + add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + op_ptrs); + } + } +#endif + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/utility/host_tensor_generator.hpp b/library/include/ck/library/utility/host_tensor_generator.hpp index 6fd7ed8aa..e87811b76 100644 --- a/library/include/ck/library/utility/host_tensor_generator.hpp +++ b/library/include/ck/library/utility/host_tensor_generator.hpp @@ -31,6 +31,18 @@ struct GeneratorTensor_1 } }; +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::bhalf_t operator()(Is...) + { + return ck::type_convert(value); + } +}; + template <> struct GeneratorTensor_1 { @@ -43,6 +55,20 @@ struct GeneratorTensor_1 } }; +#if defined CK_ENABLE_FP8 +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::bhalf_t operator()(Is...) + { + return ck::type_convert(value); + } +}; +#endif + template <> struct GeneratorTensor_1 { diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt new file mode 100644 index 000000000..41ce4a092 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/CMakeLists.txt @@ -0,0 +1,70 @@ +# ONLY XDL_KERNELS +set(GEMM_UNIVERSAL_INSTANCES) + +list(APPEND GEMM_UNIVERSAL_INSTANCES + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp + 
device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp + + + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp + + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp + 
device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp + device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp + ) + +add_instance_library(device_gemm_universal_instance ${GEMM_UNIVERSAL_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp new file mode 100644 index 000000000..d34c83a60 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
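+//
+// Shared instance definitions for the universal f16 x f16 = f16 row-major x row-major
+// GEMM built on DeviceGemm_Xdl_CShuffleV3: the *_comp_instances tuple lists
+// compute-oriented tile shapes (Intrawave/Interwave schedulers, several pipeline
+// versions), while the *_mem_instances tuple lists latency- and memory-friendly
+// shapes; both are parameterized by the GemmSpecialization (padding) variant, the
+// memory-friendly ones additionally by the pipeline scheduler.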
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| 
_MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 2, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 4, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, 
F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 4, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 4, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 4, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + 
// clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 000000000..41d6481c9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp new file mode 100644 index 000000000..de41821d9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 000000000..cdde9fa43 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
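+//
+// Each *_instance.cpp translation unit instantiates the shared tuple from
+// device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp for one GemmSpecialization
+// (MNKPadding here, per the file name) and registers the resulting operations
+// with the factory through add_device_operation_instances.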
+ +#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp new file mode 100644 index 000000000..04237cc62 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp new file mode 100644 index 000000000..18bd4fcf3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 000000000..7661a2b5d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 000000000..887c9d4df --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp new file mode 100644 index 000000000..b6869d801 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 000000000..bc59f2ebf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 000000000..baf2cb3c4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp new file mode 100644 index 000000000..ca90efa4c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, 
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + // AGPR Spill + // DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + // AGPR Spill when use permuted lds layout. so, use padding for these two. + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 
2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 
0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 8, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, 
S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 000000000..77addd6ad --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp new file mode 100644 index 000000000..4fb034d3b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 000000000..56fb3a129 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp new file mode 100644 index 000000000..b02d57c2a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp new file mode 100644 index 000000000..7f7ec14ba --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 000000000..32634a612 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 000000000..2aa313851 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 000000000..3062add94 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 000000000..ede5e4c42 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 000000000..a1928ccc6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp new file mode 100644 index 000000000..452a9c963 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| 
_NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 4, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| 
Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 8, 4, 16, 16, 1, 1, S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 8, 4, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 128, 8, 4, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 128, 8, 4, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, 
BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 000000000..f3e96e83f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp new file mode 100644 index 000000000..d73b75fcd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 000000000..19894a440 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp new file mode 100644 index 000000000..f1123e571 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instance.cpp new file mode 100644 index 000000000..59e481fa7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 000000000..4a0990756 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 000000000..e2804cbf9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instance.cpp new file mode 100644 index 000000000..e6e41282f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 000000000..815bce781 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 000000000..33ca04996 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp new file mode 100644 index 000000000..78d16670c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| 
PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffleV3< 
Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 128, 8, 16, 32, 32, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 128, 8, 16, 16, 16, 4, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 128, 8, 16, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 128, 8, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 8, 16, 16, 16, 1, 1, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 
16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 128, 8, 16, 16, 16, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 128, 8, 16, 32, 32, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 128, 8, 16, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 8, 16, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 16, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 8, 16, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 000000000..b6d916f26 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
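Note on the *_instance.cpp translation units above: they all follow one pattern. The companion header leaves the GEMM padding specialization (GemmSpec) and the block-GEMM pipeline scheduler (BlkGemmPipeSched) open, and each .cpp file pins one (specialization, scheduler) pair and appends the resulting tuple of DeviceGemm_Xdl_CShuffleV3 configurations to a caller-supplied vector. The element type of that vector has been stripped during extraction in the hunks above (it shows up as a bare std::vector>>&). The sketch below shows the intended shape of one such function; the DeviceGemmV2 base-interface parameters and the <GemmDefault, Interwave> template arguments are assumptions to be checked against the actual sources, not a verbatim restoration of the patch content.

    // Sketch only. Assumed: DeviceGemmV2<ALayout, BLayout, CLayout, AData, BData,
    // CData, AElementOp, BElementOp, CElementOp> is the base interface these
    // instances implement, and the *_mem_v1_default unit pins <GemmDefault, Interwave>.
    #include <memory>
    #include <vector>
    #include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp"

    namespace ck {
    namespace tensor_operation {
    namespace device {
    namespace instance {

    void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(
        std::vector<std::unique_ptr<DeviceGemmV2<Row, Col, Row,
                                                 F16, F8, F16,
                                                 PassThrough, PassThrough, PassThrough>>>&
            instances)
    {
        // Append every configuration in the mem tuple, specialized for the
        // chosen padding scheme and scheduler, to the caller's vector.
        add_device_operation_instances(
            instances,
            device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances<GemmDefault,
                                                                        Interwave>{});
    }

    } // namespace instance
    } // namespace device
    } // namespace tensor_operation
    } // namespace ck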
+ +#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp new file mode 100644 index 000000000..e72a748e9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 000000000..203ada9a3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instance.cpp new file mode 100644 index 000000000..da3870567 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instance.cpp new file mode 100644 index 000000000..ea23dab56 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 000000000..0caccb4f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 000000000..13f34a1c5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 000000000..bb2961d22 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 000000000..ebdc104f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 000000000..b601313ff --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp new file mode 100644 index 000000000..f9bdde77f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 256, 64, 16, 8, 32, 32, 3, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 
8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 2, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 
16, 16, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 128, 16, 2, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 128, 16, 2, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 128, 16, 4, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 128, 16, 2, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 128, 16, 4, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 128, 16, 2, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 2, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 2, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 4, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp new file mode 100644 index 000000000..29e96d9cb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT 
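The f8_f16_f16 headers mirror the f16_f8_f16 ones with the A and B operand types swapped (FP8 A against FP16 B). On the consumer side, the add_* functions registered by these translation units fill a vector of type-erased operator pointers that a client or the instance factory can filter and launch. The sketch below shows the direct route; the DeviceGemmV2 element type, the public declaration header, and GetTypeString() on the base operator are assumptions about the library API, not guaranteed by this patch.

    // Sketch only: collect a few of the f8_f16_f16 row/row instances added by this
    // patch and print their type strings. The include path for the instance
    // declarations is an assumption; adjust to the actual library layout.
    #include <iostream>
    #include <memory>
    #include <vector>
    #include "ck/ck.hpp"
    #include "ck/library/tensor_operation_instance/gpu/gemm_universal.hpp" // assumed header

    namespace dev  = ck::tensor_operation::device;
    namespace inst = ck::tensor_operation::device::instance;

    using Row         = ck::tensor_layout::gemm::RowMajor;
    using F8          = ck::f8_t;
    using F16         = ck::half_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    int main()
    {
        std::vector<std::unique_ptr<dev::DeviceGemmV2<Row, Row, Row,
                                                      F8, F16, F16,
                                                      PassThrough, PassThrough, PassThrough>>>
            ops;

        // One compute-friendly set and one latency-friendly set from the new files.
        inst::add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(ops);
        inst::add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances(ops);

        for(const auto& op : ops)
            std::cout << op->GetTypeString() << '\n'; // GetTypeString(): assumed BaseOperator API
        return 0;
    }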
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp new file mode 100644 index 000000000..30c314bde --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 000000000..84ac5b501 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp new file mode 100644 index 000000000..1684aff67 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp new file mode 100644 index 000000000..95a54e556 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 000000000..a1930526a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 000000000..0c53f3555 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp new file mode 100644 index 000000000..44883a5bb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 000000000..87871f39f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 000000000..10c28742e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,27 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp new file mode 100644 index 000000000..af4008c91 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 16, 8, 16, 16, 8, 7, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 
1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> + // clang-format on + >; + +template +using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Latency friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 
2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 128, 16, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 128, 16, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 128, 16, 8, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 128, 16, 8, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 128, 16, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 128, 16, 8, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 8, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 
128, 16, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 128, 16, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 128, 16, 8, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 16, 8, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp new file mode 100644 index 000000000..49236e17f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp new file mode 100644 index 000000000..8cc83f73e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
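+// Registers the compute-pipeline ("comp") f8 * f16 = f16 universal GEMM instances,
+// specialized for K padding, with the device operation instance factory. The layout
+// tag in the file name (mk_nk_mn) denotes A[m, k] row-major, B[n, k] column-major and
+// C[m, n] row-major operands.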
+ +#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp new file mode 100644 index 000000000..74299cd55 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instance.cpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp new file mode 100644 index 000000000..d0561c0af --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instance.cpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
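+// Registers the compute-pipeline ("comp") f8 * f16 = f16 universal GEMM instances,
+// specialized for M/N padding, with the device operation instance factory.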
+ +#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp new file mode 100644 index 000000000..7c8820856 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp new file mode 100644 index 000000000..3713223d4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
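+// Registers the memory-friendly ("mem", pipeline v1) f8 * f16 = f16 universal GEMM
+// instances, specialized for K padding, with the device operation instance factory.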
+ +#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp new file mode 100644 index 000000000..2814dee43 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 000000000..aead87940 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
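+// Registers the memory-friendly ("mem", pipeline v2) f8 * f16 = f16 universal GEMM
+// instances with the default (non-padded) specialization.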
+ +#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp new file mode 100644 index 000000000..69078655e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp new file mode 100644 index 000000000..ae0816ffc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances( + std::vector>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_universal_impl.hpp b/profiler/include/profiler/profile_gemm_universal_impl.hpp new file mode 100644 index 000000000..c77541e0e --- /dev/null +++ b/profiler/include/profiler/profile_gemm_universal_impl.hpp @@ -0,0 +1,299 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
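+// Host-side profiling routine for the universal (split-K capable) GEMM instances.
+// It allocates and initializes host/device tensors, optionally checks every device
+// instance against a reference CPU GEMM (with relaxed fp8 tolerances), sweeps a list
+// of split-K (KBatch) values unless a specific KBatch is requested, and reports the
+// best-performing instance in TFlops and GB/s.
+//
+// Minimal usage sketch (the template-argument order shown is illustrative; the runtime
+// arguments follow the signature below):
+//
+//   bool ok = ck::profiler::profile_gemm_universal_impl<ADataType, BDataType, AccDataType,
+//                                                       CDataType, ALayout, BLayout, CLayout>(
+//       do_verification, init_method, do_log, time_kernel,
+//       M, N, K, StrideA, StrideB, StrideC, KBatch, n_warmup, n_iter);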
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_universal.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_universal_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int KBatch, + int n_warmup, + int n_iter) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmV2; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference GEMM + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float 
best_kbatch = 0; + + // profile device GEMM instances + for(auto& op_ptr : op_ptrs) + { + std::vector kbatch_list = {1, 2, 4, 8, 12, 16, 19, 20, 32, 38}; + + if(KBatch > 0) + { + kbatch_list = {KBatch}; + } + + for(std::size_t i = 0; i < kbatch_list.size(); i++) + { + auto kbatch_curr = kbatch_list[i]; + + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + kbatch_curr, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " + << kbatch_curr << std::endl; + +#if defined CK_ENABLE_FP8 + // set softer tolerances for fp8 + if constexpr(is_same_v || is_same_v || + is_same_v) + { + std::string msg = "Error: Incorrect results!"; + double rtol = 1e-1; + double atol = 1e-1; + pass = pass & ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, msg, rtol, atol); + } + else + { +#endif + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#if defined CK_ENABLE_FP8 + } +#endif + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " 
<< StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch + << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec + << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index e8992070b..ce813d05a 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -49,6 +49,7 @@ if(GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_SOURCES profile_gemm_add_multiply.cpp) list(APPEND PROFILER_SOURCES profile_gemm_bias_add_reduce.cpp) list(APPEND PROFILER_SOURCES profile_gemm_splitk.cpp) + list(APPEND PROFILER_SOURCES profile_gemm_universal.cpp) list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu.cpp) list(APPEND PROFILER_SOURCES profile_conv_fwd_bias_relu_add.cpp) list(APPEND PROFILER_SOURCES profile_conv_bwd_data.cpp) @@ -115,6 +116,7 @@ if(GPU_TARGETS MATCHES "gfx9") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_universal_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_add_multiply_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bias_add_reduce_instance) diff --git a/profiler/src/profile_gemm_universal.cpp b/profiler/src/profile_gemm_universal.cpp new file mode 100644 index 000000000..940ef09e5 --- /dev/null +++ b/profiler/src/profile_gemm_universal.cpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
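+// Command-line driver for the "gemm_universal" profiler operation: it parses the data
+// type, layout, verification/initialization/logging/timing flags, the GEMM sizes and
+// strides, the split-K factor (KBatch) and optional warm-up/iteration counts, then
+// dispatches to profile_gemm_universal_impl for the supported type/layout combinations.
+//
+// Illustrative invocation (the profiler executable name is assumed here; negative
+// strides select the default packed strides):
+//   ckProfiler gemm_universal 1 1 1 1 0 1 3840 4096 4096 -1 -1 -1 1
+//   -> fp16 A/B/C, A[m, k] * B[n, k] = C[m, n], verification on, integer init, no tensor
+//      dump, timed kernel, M=3840 N=4096 K=4096, default strides, KBatch=1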
+ +#include +#include +#include +#include + +#include "profiler/profile_gemm_universal_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F16_F16, // 4 + F16_F8_F16, // 5 + F16_F16_F16_F8, // 6 +}; + +#define OP_NAME "gemm_universal" +#define OP_DESC "Universal GEMM" + +int profile_gemm_universal(int argc, char* argv[]) +{ + if(argc != 15 && argc != 17) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, " + "comp f8)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: split k into mulitiple batch\n"); + printf("optional:\n"); + printf("arg15: number of warm-up cycles (default 1)\n"); + printf("arg16: number of iterations (default 10)\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + const int KBatch = std::stoi(argv[14]); + + int n_warmup = 1; + int n_iter = 10; + if(argc == 17) + { + n_warmup = std::stoi(argv[15]); + n_iter = std::stoi(argv[16]); + } + + using F32 = float; + using F16 = ck::half_t; + using F8 = ck::f8_t; + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto acc_type, + auto c_type, + auto a_layout, + auto b_layout, + auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using AccDataType = decltype(acc_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_universal_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideC < 0) ? DefaultStrideC : StrideC, + KBatch, + n_warmup, + n_iter); + + return pass ? 
0 : 1; + }; + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } + else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); + } + else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_universal); diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 51d6f7a30..a28e72357 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -8,7 +8,7 @@ MY_PROJECT_SOURCE=$1 cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ +-D CMAKE_CXX_FLAGS="-Xclang -mllvm -Xclang -enable-post-misched=0 -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ -D GPU_TARGETS="gfx908;gfx90a;gfx940" \ diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index bbb75c49e..33aa10df7 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -149,6 +149,7 @@ add_subdirectory(gemm) add_subdirectory(gemm_add) add_subdirectory(gemm_layernorm) add_subdirectory(gemm_split_k) +add_subdirectory(gemm_universal) add_subdirectory(gemm_reduce) add_subdirectory(batched_gemm) add_subdirectory(batched_gemm_reduce) diff --git a/test/gemm_universal/CMakeLists.txt b/test/gemm_universal/CMakeLists.txt new file mode 100644 index 000000000..4aab6323c --- /dev/null +++ b/test/gemm_universal/CMakeLists.txt @@ -0,0 +1,4 @@ +add_gtest_executable(test_gemm_universal test_gemm_universal_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_universal PRIVATE utility device_gemm_universal_instance) + endif() diff --git a/test/gemm_universal/test_gemm_universal_ut_cases.inc b/test/gemm_universal/test_gemm_universal_ut_cases.inc new file mode 100644 index 000000000..e3f8f8a26 --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_ut_cases.inc @@ -0,0 +1,113 @@ +#pragma once + +TYPED_TEST(TestGemmUniversal_MK_KN, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_MK_KN, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr 
int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 320; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_MK_KN, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_MK_NK, PaddK) +{ + std::vector Ms{127}; + constexpr int N = 512; + constexpr int K = 437; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_MK_KN, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmUniversal_MK_NK, Regular) +{ + std::vector Ms{512}; + constexpr int N = 512; + constexpr int K = 512; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} diff --git a/test/gemm_universal/test_gemm_universal_util.hpp b/test/gemm_universal/test_gemm_universal_util.hpp new file mode 100644 index 000000000..9f101191d --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_util.hpp @@ -0,0 +1,91 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
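+// Googletest fixture shared by the universal GEMM tests. Each case forwards its problem
+// size to profile_gemm_universal_impl with verification enabled and benchmarking disabled,
+// and Run() repeats the problem for every split-K factor in k_batches_ ({1, 2, 3, 5, 8} by
+// default) so that the KBatch > 1 code paths are exercised as well.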
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "include/ck/utility/data_type.hpp" +#include "profiler/profile_gemm_universal_impl.hpp" + +namespace ck { +namespace test { + +template +class TestGemmUniversal : public testing::Test +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using F32 = float; + + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using ADataType = std::tuple_element_t<2, Tuple>; + using BDataType = std::tuple_element_t<3, Tuple>; + using CDataType = std::tuple_element_t<4, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // decimal value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; } + + void Run(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC) + { + for(auto kb : k_batches_) + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); + } + } + + void RunSingle(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) + { + bool pass = ck::profiler::profile_gemm_universal_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + kbatch, + n_warmup, + n_iter); + EXPECT_TRUE(pass); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_universal/test_gemm_universal_xdl.cpp b/test/gemm_universal/test_gemm_universal_xdl.cpp new file mode 100644 index 000000000..0c485e02a --- /dev/null +++ b/test/gemm_universal/test_gemm_universal_xdl.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
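+// Instantiates the typed universal GEMM test suites for the XDL path. KernelTypes covers
+// the (A, B, C) data-type combinations f16/f16/f16, f16/f8/f16 and f8/f16/f16; each is
+// exercised for both the MK_KN (row/row) and MK_NK (row-major A, column-major B) layouts
+// through the shared cases in test_gemm_universal_ut_cases.inc.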
+ +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_universal_util.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversal_MK_KN + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +template +class TestGemmUniversal_MK_NK + : public ck::test::TestGemmUniversal, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< + // ADataType, BDataType, CDataType + std::tuple< F16, F16, F16>, + std::tuple< F16, F8, F16>, + std::tuple< F8, F16, F16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversal_MK_KN, KernelTypes); +TYPED_TEST_SUITE(TestGemmUniversal_MK_NK, KernelTypes); + +#include "test_gemm_universal_ut_cases.inc" -- GitLab From dd34ab6e64c3ecebe27e751a39c4472bf34b81f3 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 15 Apr 2024 08:01:22 -0700 Subject: [PATCH 29/63] add CK_USE_XDL/WMMA for client examples (#1238) --- client_example/CMakeLists.txt | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 3aa9efa31..8eb662d28 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -48,6 +48,21 @@ else() endif() endif() +if (GPU_TARGETS) + if (GPU_TARGETS MATCHES "gfx9") + add_definitions(-DCK_USE_XDL) + set(CK_USE_XDL "ON") + endif() + if (GPU_TARGETS MATCHES "gfx11") + add_definitions(-DCK_USE_WMMA) + set(CK_USE_WMMA "ON") + endif() +else() + add_definitions(-DCK_USE_WMMA -DCK_USE_XDL) + set(CK_USE_XDL "ON") + set(CK_USE_WMMA "ON") +endif() + find_package(composable_kernel COMPONENTS device_other_operations device_gemm_operations device_conv_operations device_reduction_operations) if(GPU_TARGETS MATCHES "gfx9") find_package(composable_kernel COMPONENTS device_contraction_operations) -- GitLab From db376dd8a4eb36c6f8e9b100b89d8a1371d76f4c Mon Sep 17 00:00:00 2001 From: carlushuang Date: Tue, 16 Apr 2024 08:27:12 +0800 Subject: [PATCH 30/63] introducing ck_tile! (#1216) * enable gfx940 * switch between intrinsic mfma routines on mi100/200 and mi300 * fix mfma_int8 on MI300 * disable 2 int8 examples on MI300 * Update cmake-ck-dev.sh * restore gitignore file * modify Jenkinsfile to the internal repo * Bump rocm-docs-core from 0.24.0 to 0.29.0 in /docs/sphinx Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.24.0 to 0.29.0. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/RadeonOpenCompute/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.24.0...v0.29.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... 
Signed-off-by: dependabot[bot] * initial enablement of gfx950 * fix clang format * disable examples 31 and 41 int8 on gfx950 * add code * fix build wip * fix xx * now can build * naming * minor fix * wip fix * fix macro for exp2; fix warpgemm a/b in transposedC * unify as tuple_array * Update the required Python version to 3.9 * Update executable name in test scripts * re-structure tuple/array to avoid spill * Merge function templates * Fix format * Add constraint to array<> ctor * Re-use function * Some minor changes * remove wrong code in store_raw() * fix compile issue in transpose * Rename enum Rename 'cood_transform_enum' to 'coord_transform_enum' * let more integral_constant->constant, and formating * make sure thread_buffer can be tuple/array * temp fix buffer_store spill * not using custom data type by default, now we can have ISA-level same code as opt_padding * fix compile error, fp8 not ready now * fix fp8 duplicated move/shift/and/or problem * Default use CK_TILE_FLOAT_TO_FP8_STOCHASTIC rounding mode * fix scratch in fp8 kernel * update some readme * fix merge from upstream * sync with upstream * sync upstream again * sync 22 * remove unused * fix clang-format * update README of ck_tile example * fix several issue * let python version to be 3.8 as minimal * remove ck_tile example from default cmake target like all/install/check * remove mistake * 1).support receipe in generate.py 2).use simplified mask type 3).change left/right to pass into karg * fix some bug in group-mode masking and codegen. update README * F8 quantization for FMHA forward (#1224) * Add SAccElementFunction, PComputeElementFunction, OAccElementFunction in pipeline * Add element function to fmha api * Adjust P elementwise function * Fix bug of elementwise op, our elementwise op is not inout * Add some elementwise op, prepare to quantization * Let generate.py can generate different elementwise function * To prevent compiler issue, remove the elementwise function we have not used. * Remove f8 pipeline, we should share the same pipeline even in f8 * Remove remove_cvref_t * Avoid warning * Fix wrong fp8 QK/KV block gemm setting * Check fp8 rounding error in check_err() * Set fp8 rounding error for check_err() * Use CK_TILE_FLOAT_TO_FP8_STANDARD as default fp8 rounding mode * 1. codgen the f8 api and kernel 2. 
f8 host code * prevent warning in filter mode * Remove not-in-use elementwise function kargs * Remove more not-in-use elementwise function kargs * Small refinements in C++ source files * Use conditional_t<> to simplify code * Support heterogeneous argument for binary function types * Re-use already-existing scales<> functor template * Fix wrong value produced by saturating * Generalize the composes<> template * Unify saturates<> implementation * Fix type errors in composes<> * Extend less_equal<> * Reuse the existing template less_equal<> in check_err() * Add equal & equal * Rename check_err() parameter * Rename check_err() parameter * Add FIXME comment for adding new macro in future * Remove unnecessary cast to void * Eliminate duplicated code * Avoid dividing api pool into more than 2 groups * Use more clear variable names * Use affirmative condition in if stmt * Remove blank lines * Donot perfect forwarding in composes<> * To fix compile error, revert generate.py back to 4439cc107dd90302d68a6494bdd33113318709f8 * Fix bug of p element function * Add compute element op to host softmax * Remove element function in api interface * Extract user parameter * Rename pscale and oscale variable * rename f8 to fp8 * rename more f8 to fp8 * Add pipeline::operator() without element_functor * 1. Remove deprecated pipeline enum 2. Refine host code parameter * Use quantization range as input * 1. Rename max_dtype to dtype_max. 2. Rename scale to scale_s 3.Add init description * Refine description * prevent early return * unify _squant kernel name in cpp, update README * Adjust the default range. * Refine error message and bias range * Add fp8 benchmark and smoke test * fix fp8 swizzle_factor=4 case --------- Co-authored-by: Po Yen Chen Co-authored-by: carlushuang --------- Signed-off-by: dependabot[bot] Co-authored-by: illsilin Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: Jing Zhang Co-authored-by: zjing14 Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Po-Yen, Chen Co-authored-by: rocking --- .gitignore | 2 + CMakeLists.txt | 2 + cmake/EnableCompilerWarnings.cmake | 1 + example/ck_tile/01_fmha/CMakeLists.txt | 44 + example/ck_tile/01_fmha/README.md | 117 + example/ck_tile/01_fmha/fmha_fwd.cpp | 716 ++++++ example/ck_tile/01_fmha/fmha_fwd.hpp | 269 +++ example/ck_tile/01_fmha/generate.py | 605 +++++ example/ck_tile/01_fmha/mask.hpp | 159 ++ example/ck_tile/01_fmha/misc/gamc.png | Bin 0 -> 30073 bytes example/ck_tile/01_fmha/script/benchmark.sh | 32 + example/ck_tile/01_fmha/script/smoke_test.sh | 45 + example/ck_tile/01_fmha/utils.hpp | 92 + example/ck_tile/CMakeLists.txt | 5 + example/ck_tile/remod.py | 21 + include/ck_tile/README.md | 48 + include/ck_tile/core.hpp | 58 + include/ck_tile/core/README.md | 18 + .../core/algorithm/cluster_descriptor.hpp | 38 + .../core/algorithm/coordinate_transform.hpp | 1672 ++++++++++++++ .../core/algorithm/space_filling_curve.hpp | 166 ++ .../core/arch/amd_buffer_addressing.hpp | 2031 +++++++++++++++++ include/ck_tile/core/arch/arch.hpp | 93 + include/ck_tile/core/arch/utility.hpp | 62 + include/ck_tile/core/config.hpp | 156 ++ include/ck_tile/core/container/array.hpp | 251 ++ .../core/container/container_helper.hpp | 499 ++++ include/ck_tile/core/container/map.hpp | 164 ++ .../core/container/meta_data_buffer.hpp | 99 + .../ck_tile/core/container/multi_index.hpp | 100 + include/ck_tile/core/container/sequence.hpp | 1114 +++++++++ include/ck_tile/core/container/span.hpp | 
78 + .../container/statically_indexed_array.hpp | 41 + .../ck_tile/core/container/thread_buffer.hpp | 165 ++ include/ck_tile/core/container/tuple.hpp | 781 +++++++ include/ck_tile/core/numeric/bfloat16.hpp | 342 +++ include/ck_tile/core/numeric/float8.hpp | 871 +++++++ include/ck_tile/core/numeric/half.hpp | 385 ++++ include/ck_tile/core/numeric/integer.hpp | 13 + .../core/numeric/integral_constant.hpp | 83 + include/ck_tile/core/numeric/math.hpp | 539 +++++ include/ck_tile/core/numeric/numeric.hpp | 191 ++ include/ck_tile/core/numeric/type_convert.hpp | 66 + include/ck_tile/core/numeric/vector_type.hpp | 185 ++ include/ck_tile/core/tensor/buffer_view.hpp | 1068 +++++++++ include/ck_tile/core/tensor/load_tile.hpp | 81 + include/ck_tile/core/tensor/null_tensor.hpp | 12 + .../ck_tile/core/tensor/null_tile_window.hpp | 88 + include/ck_tile/core/tensor/shuffle_tile.hpp | 177 ++ include/ck_tile/core/tensor/slice_tile.hpp | 92 + .../core/tensor/static_distributed_tensor.hpp | 190 ++ include/ck_tile/core/tensor/store_tile.hpp | 93 + include/ck_tile/core/tensor/sweep_tile.hpp | 30 + .../ck_tile/core/tensor/tensor_adaptor.hpp | 945 ++++++++ .../core/tensor/tensor_adaptor_coordinate.hpp | 257 +++ .../ck_tile/core/tensor/tensor_coordinate.hpp | 92 + .../ck_tile/core/tensor/tensor_descriptor.hpp | 467 ++++ include/ck_tile/core/tensor/tensor_view.hpp | 281 +++ .../ck_tile/core/tensor/tile_distribution.hpp | 759 ++++++ .../tensor/tile_distribution_encoding.hpp | 760 ++++++ .../ck_tile/core/tensor/tile_elementwise.hpp | 263 +++ include/ck_tile/core/tensor/tile_window.hpp | 740 ++++++ include/ck_tile/core/utility/bit_cast.hpp | 19 + include/ck_tile/core/utility/functional.hpp | 208 ++ include/ck_tile/core/utility/ignore.hpp | 22 + include/ck_tile/core/utility/magic_div.hpp | 240 ++ include/ck_tile/core/utility/random.hpp | 58 + include/ck_tile/core/utility/to_sequence.hpp | 73 + .../core/utility/transpose_vectors.hpp | 125 + include/ck_tile/core/utility/type_traits.hpp | 95 + .../core/utility/unary_element_function.hpp | 67 + include/ck_tile/host.hpp | 22 + include/ck_tile/host/arg_parser.hpp | 184 ++ include/ck_tile/host/check_err.hpp | 394 ++++ include/ck_tile/host/device_memory.hpp | 112 + include/ck_tile/host/fill.hpp | 232 ++ include/ck_tile/host/hip_check_error.hpp | 36 + include/ck_tile/host/host_tensor.hpp | 523 +++++ include/ck_tile/host/kernel_launch.hpp | 166 ++ include/ck_tile/host/ranges.hpp | 69 + .../reference_batched_elementwise.hpp | 64 + .../host/reference/reference_batched_gemm.hpp | 50 + .../reference/reference_batched_masking.hpp | 32 + .../reference/reference_batched_softmax.hpp | 71 + .../ck_tile/host/reference/reference_gemm.hpp | 50 + .../host/reference/reference_im2col.hpp | 61 + .../host/reference/reference_reduce.hpp | 32 + .../host/reference/reference_softmax.hpp | 51 + include/ck_tile/host/stream_config.hpp | 17 + include/ck_tile/ops/common.hpp | 6 + include/ck_tile/ops/common/README.md | 4 + include/ck_tile/ops/common/tensor_layout.hpp | 412 ++++ include/ck_tile/ops/epilogue.hpp | 7 + .../ops/epilogue/default_2d_epilogue.hpp | 50 + include/ck_tile/ops/fmha.hpp | 21 + .../ck_tile/ops/fmha/block/block_masking.hpp | 366 +++ .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 716 ++++++ .../fmha/kernel/fmha_fwd_tile_partitioner.hpp | 54 + .../pipeline/block_fmha_pipeline_enum.hpp | 16 + .../pipeline/block_fmha_pipeline_problem.hpp | 61 + .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 597 +++++ .../block_fmha_pipeline_qr_ks_vs_async.hpp | 695 ++++++ 
...pipeline_qr_ks_vs_async_default_policy.hpp | 19 + ..._fmha_pipeline_qr_ks_vs_default_policy.hpp | 19 + .../block_fmha_pipeline_qr_ks_vs_fp8.hpp | 507 ++++ .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 588 +++++ ..._fmha_pipeline_qs_ks_vs_default_policy.hpp | 19 + ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 959 ++++++++ .../ops/fmha/pipeline/tile_fmha_shape.hpp | 46 + .../ops/fmha/pipeline/tile_fmha_traits.hpp | 30 + include/ck_tile/ops/gemm.hpp | 31 + .../block_gemm_areg_bgmem_creg_problem.hpp | 25 + .../block/block_gemm_areg_bgmem_creg_v1.hpp | 135 ++ ...gemm_areg_bgmem_creg_v1_default_policy.hpp | 110 + .../block_gemm_areg_bsmem_creg_problem.hpp | 26 + .../block/block_gemm_areg_bsmem_creg_v1.hpp | 340 +++ ..._gemm_areg_bsmem_creg_v1_custom_policy.hpp | 36 + ...gemm_areg_bsmem_creg_v1_default_policy.hpp | 56 + .../block/block_gemm_areg_bsmem_creg_v2.hpp | 227 ++ ..._gemm_areg_bsmem_creg_v2_custom_policy.hpp | 36 + ...gemm_areg_bsmem_creg_v2_default_policy.hpp | 46 + .../block_gemm_asmem_bsmem_creg_problem.hpp | 26 + .../block/block_gemm_asmem_bsmem_creg_v1.hpp | 213 ++ ...gemm_asmem_bsmem_creg_v1_custom_policy.hpp | 38 + ...emm_asmem_bsmem_creg_v1_default_policy.hpp | 55 + ...lock_gemm_pipeline_agmem_bgmem_creg_v1.hpp | 200 ++ ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 251 ++ ...lock_gemm_pipeline_agmem_bgmem_creg_v2.hpp | 218 ++ ...ine_agmem_bgmem_creg_v2_default_policy.hpp | 18 + .../pipeline/block_gemm_pipeline_problem.hpp | 25 + .../ops/gemm/pipeline/tile_gemm_shape.hpp | 18 + include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 105 + .../gemm/warp/warp_gemm_attribute_mfma.hpp | 471 ++++ .../warp/warp_gemm_attribute_mfma_impl.hpp | 379 +++ .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 65 + .../ck_tile/ops/gemm/warp/warp_gemm_impl.hpp | 74 + include/ck_tile/ops/reduce.hpp | 7 + .../ck_tile/ops/reduce/block/block_reduce.hpp | 211 ++ include/ck_tile/remod.py | 88 + script/cmake-ck-dev.sh | 8 +- script/cmake-ck-release.sh | 8 +- 141 files changed, 30623 insertions(+), 2 deletions(-) create mode 100644 example/ck_tile/01_fmha/CMakeLists.txt create mode 100644 example/ck_tile/01_fmha/README.md create mode 100644 example/ck_tile/01_fmha/fmha_fwd.cpp create mode 100644 example/ck_tile/01_fmha/fmha_fwd.hpp create mode 100644 example/ck_tile/01_fmha/generate.py create mode 100644 example/ck_tile/01_fmha/mask.hpp create mode 100644 example/ck_tile/01_fmha/misc/gamc.png create mode 100755 example/ck_tile/01_fmha/script/benchmark.sh create mode 100755 example/ck_tile/01_fmha/script/smoke_test.sh create mode 100644 example/ck_tile/01_fmha/utils.hpp create mode 100644 example/ck_tile/CMakeLists.txt create mode 100644 example/ck_tile/remod.py create mode 100644 include/ck_tile/README.md create mode 100644 include/ck_tile/core.hpp create mode 100644 include/ck_tile/core/README.md create mode 100644 include/ck_tile/core/algorithm/cluster_descriptor.hpp create mode 100644 include/ck_tile/core/algorithm/coordinate_transform.hpp create mode 100644 include/ck_tile/core/algorithm/space_filling_curve.hpp create mode 100644 include/ck_tile/core/arch/amd_buffer_addressing.hpp create mode 100644 include/ck_tile/core/arch/arch.hpp create mode 100644 include/ck_tile/core/arch/utility.hpp create mode 100644 include/ck_tile/core/config.hpp create mode 100644 include/ck_tile/core/container/array.hpp create mode 100644 include/ck_tile/core/container/container_helper.hpp create mode 100644 include/ck_tile/core/container/map.hpp create mode 100644 include/ck_tile/core/container/meta_data_buffer.hpp create mode 
100644 include/ck_tile/core/container/multi_index.hpp create mode 100644 include/ck_tile/core/container/sequence.hpp create mode 100644 include/ck_tile/core/container/span.hpp create mode 100644 include/ck_tile/core/container/statically_indexed_array.hpp create mode 100644 include/ck_tile/core/container/thread_buffer.hpp create mode 100644 include/ck_tile/core/container/tuple.hpp create mode 100644 include/ck_tile/core/numeric/bfloat16.hpp create mode 100644 include/ck_tile/core/numeric/float8.hpp create mode 100644 include/ck_tile/core/numeric/half.hpp create mode 100644 include/ck_tile/core/numeric/integer.hpp create mode 100644 include/ck_tile/core/numeric/integral_constant.hpp create mode 100644 include/ck_tile/core/numeric/math.hpp create mode 100644 include/ck_tile/core/numeric/numeric.hpp create mode 100644 include/ck_tile/core/numeric/type_convert.hpp create mode 100644 include/ck_tile/core/numeric/vector_type.hpp create mode 100644 include/ck_tile/core/tensor/buffer_view.hpp create mode 100644 include/ck_tile/core/tensor/load_tile.hpp create mode 100644 include/ck_tile/core/tensor/null_tensor.hpp create mode 100644 include/ck_tile/core/tensor/null_tile_window.hpp create mode 100644 include/ck_tile/core/tensor/shuffle_tile.hpp create mode 100644 include/ck_tile/core/tensor/slice_tile.hpp create mode 100644 include/ck_tile/core/tensor/static_distributed_tensor.hpp create mode 100644 include/ck_tile/core/tensor/store_tile.hpp create mode 100644 include/ck_tile/core/tensor/sweep_tile.hpp create mode 100644 include/ck_tile/core/tensor/tensor_adaptor.hpp create mode 100644 include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp create mode 100644 include/ck_tile/core/tensor/tensor_coordinate.hpp create mode 100644 include/ck_tile/core/tensor/tensor_descriptor.hpp create mode 100644 include/ck_tile/core/tensor/tensor_view.hpp create mode 100644 include/ck_tile/core/tensor/tile_distribution.hpp create mode 100644 include/ck_tile/core/tensor/tile_distribution_encoding.hpp create mode 100644 include/ck_tile/core/tensor/tile_elementwise.hpp create mode 100644 include/ck_tile/core/tensor/tile_window.hpp create mode 100644 include/ck_tile/core/utility/bit_cast.hpp create mode 100644 include/ck_tile/core/utility/functional.hpp create mode 100644 include/ck_tile/core/utility/ignore.hpp create mode 100644 include/ck_tile/core/utility/magic_div.hpp create mode 100644 include/ck_tile/core/utility/random.hpp create mode 100644 include/ck_tile/core/utility/to_sequence.hpp create mode 100644 include/ck_tile/core/utility/transpose_vectors.hpp create mode 100644 include/ck_tile/core/utility/type_traits.hpp create mode 100644 include/ck_tile/core/utility/unary_element_function.hpp create mode 100644 include/ck_tile/host.hpp create mode 100644 include/ck_tile/host/arg_parser.hpp create mode 100644 include/ck_tile/host/check_err.hpp create mode 100644 include/ck_tile/host/device_memory.hpp create mode 100644 include/ck_tile/host/fill.hpp create mode 100644 include/ck_tile/host/hip_check_error.hpp create mode 100644 include/ck_tile/host/host_tensor.hpp create mode 100644 include/ck_tile/host/kernel_launch.hpp create mode 100644 include/ck_tile/host/ranges.hpp create mode 100644 include/ck_tile/host/reference/reference_batched_elementwise.hpp create mode 100644 include/ck_tile/host/reference/reference_batched_gemm.hpp create mode 100644 include/ck_tile/host/reference/reference_batched_masking.hpp create mode 100644 include/ck_tile/host/reference/reference_batched_softmax.hpp create mode 100644 
include/ck_tile/host/reference/reference_gemm.hpp create mode 100644 include/ck_tile/host/reference/reference_im2col.hpp create mode 100644 include/ck_tile/host/reference/reference_reduce.hpp create mode 100644 include/ck_tile/host/reference/reference_softmax.hpp create mode 100644 include/ck_tile/host/stream_config.hpp create mode 100644 include/ck_tile/ops/common.hpp create mode 100644 include/ck_tile/ops/common/README.md create mode 100644 include/ck_tile/ops/common/tensor_layout.hpp create mode 100644 include/ck_tile/ops/epilogue.hpp create mode 100644 include/ck_tile/ops/epilogue/default_2d_epilogue.hpp create mode 100644 include/ck_tile/ops/fmha.hpp create mode 100644 include/ck_tile/ops/fmha/block/block_masking.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp create mode 100644 include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp create mode 100644 include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp create mode 100644 include/ck_tile/ops/gemm.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_problem.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_problem.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_problem.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp create mode 100644 include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp create mode 100644 
include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp create mode 100644 include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp create mode 100644 include/ck_tile/ops/reduce.hpp create mode 100644 include/ck_tile/ops/reduce/block/block_reduce.hpp create mode 100644 include/ck_tile/remod.py diff --git a/.gitignore b/.gitignore index 090594a8d..f4d5ff7ab 100644 --- a/.gitignore +++ b/.gitignore @@ -64,3 +64,5 @@ build*/ # Python virtualenv .venv/ +# Python cache +__pycache__/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 7b9721d05..f6dcde444 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -26,6 +26,8 @@ set(version 1.1.0) project(composable_kernel VERSION ${version} LANGUAGES CXX) include(CTest) +find_package(Python3 3.8 COMPONENTS Interpreter REQUIRED) + list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") if (DTYPES) diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index 87cb8cdf1..8654170b3 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -95,6 +95,7 @@ else() -Wno-weak-vtables -Wno-covered-switch-default -Wno-unsafe-buffer-usage + -Wno-unused-lambda-capture ) else() if (CMAKE_${COMPILER}_COMPILER_ID MATCHES "GNU" AND ${COMPILER} MATCHES "CXX") diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt new file mode 100644 index 000000000..e31c96caa --- /dev/null +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -0,0 +1,44 @@ +# generate a list of kernels, but not actually emit files at config stage +execute_process( + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt +) + +# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory +# as current cmake list, otherwise will not figure out the dependency properly +file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt FMHA_FWD_GEN_BLOBS) + +add_custom_command( + OUTPUT ${FMHA_FWD_GEN_BLOBS} + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py + --output_dir ${CMAKE_CURRENT_BINARY_DIR} +) + +set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding tile_example ${EXAMPLE_NAME}") +add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) +target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) + +# NOTE: this is dangerous since will change the whole kernel to flush denormals +# WIP with compiler team for an exp2 intrinsic..., then remove this +if(NOT DEFINED FMHA_FWD_FAST_EXP2) + set(FMHA_FWD_FAST_EXP2 true) +endif() + +set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +# ... 
because they are auto-generated
+if(FMHA_FWD_FAST_EXP2)
+    list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
+else()
+    list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
+endif()
+
+# Allow comparing floating points directly in order to check sentinel values
+list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
+
+target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md
new file mode 100644
index 000000000..5a428e4d4
--- /dev/null
+++ b/example/ck_tile/01_fmha/README.md
@@ -0,0 +1,117 @@
+# fused multi-head attention
+
+This folder contains an fmha (fused multi-head attention) example built on the ck_tile tile-programming implementation. It demonstrates how to use the tile-programming API and illustrates the approach used to construct a kernel template and instantiate it while keeping compile time short.
+
+## build
+```
+# in the root of ck_tile
+mkdir build && cd build
+sh ../script/cmake-ck-dev.sh ../ # you can replace this with gfx90a, gfx942, ...
+make tile_example_fmha_fwd -j
+```
+This produces the executable `build/bin/tile_example_fmha_fwd`.
+
+## kernel
+The kernel template is `fmha_fwd_kernel.hpp`; this is the grid-wise op in old ck_tile terminology. It is kept in this folder on purpose, to demonstrate that a kernel can be constructed from the various internal components of ck_tile. An implementation of the kernel template may also be provided under the ck_tile include path in the future.
+
+The kernel template has 3 template parameters.
+* `TilePartitioner` maps a workgroup to the corresponding tile; `fmha_fwd_tile_partitioner.hpp` in this folder serves this purpose.
+* `FmhaPipeline` is one of the block tile pipelines (under `include/ck_tile/tile_program/block_tile_pipeline`) and is the performance-critical component. A lot of optimization and experimentation went into these pipelines, and more performant variants may be added to that folder over time. Replacing this pipeline type is all that is needed to pick up a different implementation (stay tuned for updated pipelines).
+* `EpiloguePipeline` modifies and stores out the result in the last phase. Post-fusion usually happens at this stage, so the concept is abstracted as well. The epilogue currently does very little, but it leaves room for future extensions.
+
+## codegen
+To speed up compilation, the kernels are instantiated into separate files so the CMake/Make system can build them in parallel. This is done by the `generate.py` script. You can also read this script to learn how a kernel instance is put together step by step, as described by the `FMHA_FWD_KERNEL_BODY` variable.
+
+## executable
+`tile_example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. Run `./bin/tile_example_fmha_fwd -?` to list all supported arguments. Below is an example of the output (subject to change):
+```
+args:
+          -v    whether to do CPU validation or not (default:1)
+       -mode    kernel mode. 0:batch, 1:group (default:0)
+          -b    batch size (default:2)
+          -h    num of head, for q (default:8)
+        -h_k    num of head, for k/v, 0 means equal to h (default:0)
+                if not equal to h, then this is GQA/MQA case
+          -s    seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
+                total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
+        -s_k    seqlen_k, 0 means equal to s (default:0)
+          -d    head dim for q, k (default:128)
+        -d_v    head dim for v, 0 means equal to d (default:0)
+    -scale_s    scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
+                note when squant=1, this value will be modified by range_q/k
+    -range_q    per-tensor quantization range of q. used if squant=1. (default:16)
+    -range_k    per-tensor quantization range of k. used if squant=1. (default:16)
+    -range_v    per-tensor quantization range of v. used if squant=1. (default:16)
+    -range_p    per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
+    -range_o    per-tensor quantization range of o (p*v). used if squant=1. (default:16)
+     -squant    if using static quantization fusion or not. 0: original flow (not preferred) (default:0)
+                1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,
+                scale_o according to range_q, range_k, range_v, range_p, range_o
+      -iperm    permute input (default:1)
+                if true, will be b*h*s*d, else b*s*h*d
+      -operm    permute output (default:1)
+       -bias    add bias or not (default:0)
+       -prec    data type. fp16/bf16/fp8/bf8 (default:fp16)
+       -mask    0: no mask, 1: top-left (same as 't'), 2: bottom-right (same as 'b') (default:0)
+                't', top-left causal mask, 'b', bottom-r causal mask
+                't:l,r', top-left sliding window attn(swa) with FA style left/right size
+                'b:l,r', bottom-r sliding window attn(swa) with FA style left/right size
+                'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
+                'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
+                'g:y,x', generic attention mask coordinate with y/x size (only for debug purposes for now)
+
+    -vlayout    r for row-major (seqlen*hdim), c for col-major (hdim*seqlen) (default:r)
+        -lse    0: do not store lse, 1: store lse (default:0)
+      -kname    if set to 1, will print the kernel name (default:0)
+       -init    init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1)
+```
+Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` runs an fp16 fmha case with batch=1, nhead=16, sequence length=16384, and hdim=128.
+
+## supported features
+We are still in a rapid development stage, so more features and optimizations are coming soon.
+
+### hdim
+hdim `32/64/128/256` is supported for `fp16`/`bf16`, with `64`/`128` being the better optimized cases. hdim should be a multiple of 8, while seqlen can be arbitrary. Arbitrary hdim values can be supported through the padding variants of the `qr` pipeline (not generated by `generate.py` by default).
+
+### group/batch mode
+Both `batch mode` and `group mode` (or `varlen`, in FA terms) are supported by setting `-mode` to `0` or `1`. In `group mode`, the different kinds of attention mask are also supported (see below).
+
+### MQA/GQA
+Setting `-h` (nhead for q) and `-h_k` (nhead for k/v) to different numbers gives you MQA/GQA. Note that `h % h_k == 0` must hold when the two numbers differ.
+
+### input/output permute, and `b*s*3*h*d`
+The kernel arguments in `fmha_fwd_kernel.hpp` accept arbitrary strides for the seqlen (stride_q/k/v), nhead, and batch dimensions of the q/k/v matrices, so both the `b*h*s*d` and `b*s*h*d` input/output permutes are supported. The `-iperm=0/1` and `-operm=0/1` flags are a convenient way to select this through the executable.
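+
+As a concrete illustration, the sketch below shows how the Q strides can be derived for the two layouts. The snippet is illustrative only and is not part of the example sources; it mirrors the batch-mode stride setup in `fmha_fwd.cpp`, which does the same for k, v, bias, lse and o:
+```
+// illustrative sketch, not part of the example
+struct q_strides { long seq, head, batch; };
+
+q_strides make_q_strides(bool i_perm, long nhead, long seqlen_q, long hdim_q)
+{
+    // i_perm == true  -> b*h*s*d : the next token of the same head is hdim_q elements away
+    // i_perm == false -> b*s*h*d : the next token is nhead * hdim_q elements away
+    const long seq_stride   = i_perm ? hdim_q : nhead * hdim_q;
+    const long head_stride  = i_perm ? seqlen_q * hdim_q : hdim_q;
+    const long batch_stride = nhead * seqlen_q * hdim_q;
+    return {seq_stride, head_stride, batch_stride};
+}
+```
+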
+There is no command-line argument for the `b*s*3*h*d` layout that torch/FA use by default, but it is trivial to support: set the appropriate `stride_q/k/v` value to `3*h*d`.
+
+### attention bias
+Attention bias is supported with a `1*1*s*s` layout (similar to input/output, a different layout can be supported by changing the stride value for the bias, or even extended to `b*h*s*s`), with bias values in float.
+
+### lse
+For training kernels, the "log-sum-exp" needs to be stored in the forward pass and used in the backward pass. This is supported by setting `-lse=1`.
+
+### vlayout
+The v matrix is supported in both row-major (`seqlen*hdim`) and col-major (`hdim*seqlen`) layouts. Since the accumulation (reduce) dimension for V is along `seqlen`, and the current AMD mfma layout expects each thread to hold contiguous registers along the reduce dimension, the col-major V layout is easier to support. However, col-major is not necessarily faster than row-major; many factors affect the overall performance. The `-vlayout=r/c` flag is provided to switch/test between the two layouts.
+
+### attention mask
+`causal mask` and `sliding window attention (swa)` masks are supported in both batch and group mode, from either the top-left or the bottom-right.
+Underneath, every mask is expressed as a `generic attention mask coordinate`, which gives each batch a uniform way to locate the pixels that need to be masked out.
+![](misc/gamc.png)
+
+Since the FA/xformer style with window_size_left/right is more popular, window_size is accepted as a parameter and converted internally to the generic coordinate (which can express more cases). The table below shows how to request different kinds of mask through the command line.
+
+| mask case | cmdline | FA style | xformer style |
+|----------|:-------------:|:-------------:|:-------------:|
+| no mask | `-mask=0` (default) | | |
+| causal mask from top-left | `-mask=1` or `-mask=t` | `-mask=t:-1,0` | `-mask=xt:-1` |
+| causal mask from bottom-right | `-mask=2` or `-mask=b` | `-mask=b:-1,0` | `-mask=xb:-1` |
+| swa from top-left | | `-mask=t:3,5` | `-mask=xt:4` |
+| swa from bottom-right | | `-mask=b:10,11` | `-mask=xb:16` |
+
+Note that FA uses bottom-right by default to express the swa case; here you must explicitly specify top-left or bottom-right.
+
+### dropout
+TBD
+
+## FP8 experimental support
+As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), there is experimental support for fp8 fmha kernels. You can evaluate the performance by passing `-prec=fp8` to `tile_example_fmha_fwd` on a gfx940/941/942 machine with ROCm 6.0+.
+
+Currently only `-vlayout=c` (`hdim*seqlen` for the V matrix) and `-squant=1` (static quantization) with `hdim=128` are supported for fp8. Full feature support will come later.
diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp
new file mode 100644
index 000000000..8ca4ff933
--- /dev/null
+++ b/example/ck_tile/01_fmha/fmha_fwd.cpp
@@ -0,0 +1,716 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "fmha_fwd.hpp"
+#include "ck_tile/host.hpp"
+#include "mask.hpp"
+#include "utils.hpp"
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+template <typename T>
+std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
+{
+    using size_type = typename std::vector<T>::size_type;
+
+    os << "[";
+    for(size_type idx = 0; idx < v.size(); ++idx)
+    {
+        if(0 < idx)
+        {
+            os << ", ";
+        }
+        os << v[idx];
+    }
+    return os << "]";
+}
+
+auto create_args(int argc, char* argv[])
+{
+    ck_tile::ArgParser arg_parser;
+    arg_parser.insert("v", "1", "whether to do CPU validation or not")
+        .insert("mode", "0", "kernel mode. 0:batch, 1:group")
+        .insert("b", "2", "batch size")
+        .insert("h", "8", "num of head, for q")
+        .insert("h_k",
+                "0",
+                "num of head, for k/v, 0 means equal to h\n"
+                "if not equal to h, then this is GQA/MQA case")
+        .insert("s",
+                "3328",
+                "seqlen_q. if group-mode, means the average value of seqlen_q\n"
+                "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary")
+        .insert("s_k", "0", "seqlen_k, 0 means equal to s")
+        .insert("d", "128", "head dim for q, k")
+        .insert("d_v", "0", "head dim for v, 0 means equal to d")
+        .insert("scale_s",
+                "0",
+                "scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
+                "note when squant=1, this value will be modified by range_q/k")
+        .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.")
+        .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.")
+        .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.")
+        .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.")
+        .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.")
+        .insert(
+            "squant",
+            "0",
+            "if using static quantization fusion or not. 0: original flow (not preferred)\n"
+            "1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p,\n"
+            "scale_o according to range_q, range_k, range_v, range_p, range_o")
+        .insert("iperm",
+                "1",
+                "permute input\n"
+                "if true, will be b*h*s*d, else b*s*h*d")
+        .insert("operm", "1", "permute output")
+        .insert("bias", "0", "add bias or not")
+        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
+        .insert("mask",
+                "0",
+                "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
+                "'t', top-left causal mask, 'b', bottom-r causal mask\n"
+                "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
+                "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
+                "'xt:window_size', xformer style masking from top-left, window_size negative is "
+                "causal, positive is swa\n"
+                "'xb:window_size', xformer style masking from bottom-r, window_size negative is "
+                "causal, positive is swa\n"
+                "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
+                "now)")
+        .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
+        .insert("lse", "0", "0 not store lse, 1 store lse")
+        .insert("kname", "0", "if set to 1 will print kernel name")
+        .insert(
+            "init", "1", "init method. 0:random int, 1:random float, 2:trig float, 3:quantization")
+        .insert("seed",
+                "11939",
+                "random seed used for initializing input tensors.
0 for " + "non-deterministic seed") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// different threshold for different dtype +template +auto get_elimit(int /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(int init_method) +{ + if(init_method == 0) + { + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); + } + else + { + double rtol = 3e-3; + double atol = 3e-3; + return ck_tile::make_tuple(rtol, atol); + } +} + +template <> +auto get_elimit(int init_method) +{ + if(init_method == 0) + { + unsigned max_rounding_point_distance = 0; + double atol = 2e-3; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } + else + { + unsigned max_rounding_point_distance = 1; + double atol = 0.0625; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + auto mode = static_cast(arg_parser.get_uint32("mode")); + ck_tile::index_t batch = arg_parser.get_int("b"); + ck_tile::index_t nhead = arg_parser.get_int("h"); + ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); + if(nhead_k == 0) + nhead_k = nhead; + + if(nhead % nhead_k != 0) + { + std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl; + return false; + } + + ck_tile::index_t seqlen_q = arg_parser.get_int("s"); + ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); + if(seqlen_k == 0) + seqlen_k = seqlen_q; + ck_tile::index_t hdim_q = arg_parser.get_int("d"); + ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); + if(hdim_v == 0) + hdim_v = hdim_q; + + bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim + bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim + + float scale_s = arg_parser.get_float("scale_s"); + if(scale_s == .0f) + scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); // TODO: q ? v ? 
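+
+    // Note: the static-quantization (squant) path below rewrites these scale factors, roughly
+    // as follows: q/k/v are stored as fp8 values scaled so that range_q/k/v maps onto the fp8
+    // maximum, which makes the raw q*k product larger than the real S by
+    // (dtype_max/range_q) * (dtype_max/range_k), so scale_s is multiplied by the inverse of
+    // that factor; scale_p re-quantizes the softmax output (bounded by range_p) to fp8; and
+    // scale_o folds the p and v quantization factors together with the requested output range
+    // range_o before o is stored as fp8.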
+ + bool squant = arg_parser.get_bool("squant"); + if constexpr(!std::is_same_v) + { + if(squant) + { + std::cerr << "static quantization only support fp8 for now" << std::endl; + return false; + } + } + + float range_q = arg_parser.get_float("range_q"); + float range_k = arg_parser.get_float("range_k"); + float range_v = arg_parser.get_float("range_v"); + float range_p = arg_parser.get_float("range_p"); + float range_o = arg_parser.get_float("range_o"); + + float dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + + float scale_p = 1.f; + float scale_o = 1.f; + + if(squant) + { + scale_s = scale_s * (range_q / dtype_max) * (range_k / dtype_max); + scale_p = dtype_max / range_p; + // scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)] + scale_o = range_p * range_v / range_o / dtype_max; + } + + std::string vlayout = arg_parser.get_str("vlayout"); + bool use_bias = arg_parser.get_bool("bias"); + bool lse = arg_parser.get_bool("lse"); + + mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); + + int init_method = arg_parser.get_int("init"); + std::optional seed = arg_parser.get_uint32("seed"); + if(*seed == 0) + { + seed.reset(); + } + + int stream_warmup = arg_parser.get_int("warmup"); + int stream_repeat = arg_parser.get_int("repeat"); + bool kname = arg_parser.get_bool("kname"); + + ck_tile::stream_config stream_config{ + nullptr, true, /* log_level = */ (kname ? 1 : 0), stream_warmup, stream_repeat}; + + const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q); + const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k); + + using TypeConfig = FmhaFwdTypeConfig; + + using QDataType = typename TypeConfig::QDataType; + using KDataType = typename TypeConfig::KDataType; + using VDataType = typename TypeConfig::VDataType; + using BiasDataType = typename TypeConfig::BiasDataType; + using LSEDataType = typename TypeConfig::LSEDataType; + using SaccDataType = typename TypeConfig::SaccDataType; + using SMPLComputeDataType = typename TypeConfig::SMPLComputeDataType; + using PDataType = typename TypeConfig::PDataType; + using OaccDataType = typename TypeConfig::OaccDataType; + using ODataType = typename TypeConfig::ODataType; + + // accumulation numbers for performance evaluation + std::size_t flop = 0, num_byte = 0; + auto max_seqlen_q = + std::numeric_limits::min(); // we will use max seqlen to decide grid size + { + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + if(max_seqlen_q < real_seqlen_q) + { + max_seqlen_q = real_seqlen_q; + } + + flop += nhead * (static_cast(2) * real_seqlen_q * real_seqlen_k * hdim_q + + static_cast(2) * real_seqlen_q * hdim_v * real_seqlen_k); + + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * hdim_v * real_seqlen_k + + sizeof(ODataType) * real_seqlen_q * hdim_v); + } + } + + auto get_lengths = [&](bool permute, + ck_tile::index_t b /*batch*/, + ck_tile::index_t h /*nhead*/, + ck_tile::index_t s /*seqlen*/, + ck_tile::index_t d /*hdim*/) { + if(permute) + return std::array{b, h, s, d}; + else + return std::array{b, s, h, d}; + }; + + bool is_v_rowmajor = vlayout == std::string("r"); + + // host memory for storing all the tensor elements + const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? 
batch : 1); + const ck_tile::index_t shape_seqlen_q = + (mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back()); + const ck_tile::index_t shape_seqlen_k = + (mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back()); + + ck_tile::HostTensor q_host( + get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); + ck_tile::HostTensor k_host( + get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); + ck_tile::HostTensor v_host( + is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) + : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)); + // use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host + // will not be used for verification at all (but will be copied to device anyway). + ck_tile::HostTensor bias_host( + use_bias + ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) + : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] + ck_tile::HostTensor lse_host( + lse ? std::array{shape_batch, nhead, shape_seqlen_q} + : std::array{1, 1, 1} /* dummy shape for simplifying code */); + + ck_tile::HostTensor o_host( + get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); + + if(init_method == 0) + { + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(q_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(k_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(v_host); + ck_tile::FillUniformDistributionIntegerValue{-2.f, 2.f, seed}(bias_host); + } + else if(init_method == 1) + { + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + } + else if(init_method == 2) + { + ck_tile::FillTrigValue{}(q_host); + ck_tile::FillTrigValue{}(k_host); + ck_tile::FillTrigValue{}(v_host); + ck_tile::FillTrigValue{}(bias_host); + } + else if(init_method == 3) // suitable for fp8 quantization + { + ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(q_host); + ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(k_host); + ck_tile::FillUniformDistribution{-dtype_max, dtype_max, seed}(v_host); + + // bias_fp8 = qscale_bias * bias_fp32 + float qscale_bias = (dtype_max / range_q) * (dtype_max / range_k); + // Assume bias is in [-1.f, 1.f] in original fp32 + ck_tile::FillUniformDistribution{-qscale_bias, qscale_bias, seed}(bias_host); + } + + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); + bias_buf.ToDevice(bias_host.data()); + seqstart_q.ToDevice(seqstart_q_host.data()); + seqstart_k.ToDevice(seqstart_k_host.data()); + + // clang-format off + auto layout_str = [&](bool permute){ + if (permute) return std::string("bhsd"); + else return 
std::string("bshd"); + }; + auto io_layout = [&](bool iperm_, bool operm_) { + if (iperm_ == operm_) return layout_str(iperm_); + else return layout_str(iperm_) + std::string("-") + layout_str(operm_); + }; + // clang-format on + const std::string prec = arg_parser.get_str("prec"); + + std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch + << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k + << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s + << ", bias:" << use_bias << ", lse:" << lse << ", squant:" << squant + << ", mask:" << mask << ", v:" << vlayout << std::flush; + + auto fmha_traits = fmha_fwd_traits{hdim_q, + hdim_v, + data_type, + mode == mode_enum::group, + is_v_rowmajor, + mask.type, + use_bias, + lse, + squant}; + + auto p_compute_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::scales{scale_p}; + else + return ck_tile::identity{}; + }(); + + auto oacc_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::composes(ck_tile::saturates{}, + ck_tile::scales{scale_o}); + else + return ck_tile::identity{}; + }(); + + auto fmha_args = [&]() { + assert(nhead % nhead_k == 0); + /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'batch_stride_bias' & + /// 'nhead_stride_bias' are 0. + // setup stride_* arguments + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k; + }(); + const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + // setup nhead_stride_* arguments + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? shape_seqlen_k * hdim_v : hdim_v; + else + return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; + }(); + const ck_tile::index_t nhead_stride_bias = + (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); + const ck_tile::index_t nhead_stride_lse = (shape_seqlen_q * 1); + const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); + // setup batch_stride_* arguments + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); + const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + + return fmha_fwd_args{q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias_buf.GetDeviceBuffer(), + lse_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + shape_seqlen_q, + shape_seqlen_k, + batch, + max_seqlen_q, + hdim_q, + hdim_v, + nhead, + nhead_k, + scale_s, + scale_p, + scale_o, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_lse, + batch_stride_o, + mask.left, + mask.right, + static_cast(mask.type)}; + }(); + + float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); + + if(ave_time < 0) + { + std::cout << ", not supported yet" << std::flush << std::endl; + return false; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, " + << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec + << " GB/s" << std::flush; + + if(!do_validation) + { + std::cout << std::flush << std::endl; + return true; + } + + o_buf.FromDevice(o_host.data()); + lse_buf.FromDevice(lse_host.data()); + + bool pass = true; + + for(ck_tile::index_t wb = 0; wb < batch; ++wb) + { + const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb]; + const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb]; + + // adjust matrix index according to the mode + const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0); + const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]); + const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]); + + const auto v_host_ref_lengths = + std::array{nhead, hdim_v, real_seqlen_k}; + const auto v_host_ref_strides = + is_v_rowmajor + ? 
std::array{hdim_v * real_seqlen_k, 1, hdim_v} + : std::array{hdim_v * real_seqlen_k, real_seqlen_k, 1}; + + ck_tile::HostTensor q_host_ref({nhead, real_seqlen_q, hdim_q}); + ck_tile::HostTensor k_host_ref({nhead, real_seqlen_k, hdim_q}); + ck_tile::HostTensor v_host_ref(v_host_ref_lengths, v_host_ref_strides); + ck_tile::HostTensor o_host_ref({nhead, real_seqlen_q, hdim_v}); + + ck_tile::HostTensor s_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor p_host_ref({nhead, real_seqlen_q, real_seqlen_k}); + ck_tile::HostTensor lse_host_ref({nhead, real_seqlen_q}); + + ck_tile::index_t nr = nhead / nhead_k; + + // clang-format off + // permute + if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); + else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); + + if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); + else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); + + if (is_v_rowmajor) { + // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); + // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); + } + else { + if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); }); + else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); }); + } + // clang-format on + + // reference + ck_tile::reference_batched_gemm( + q_host_ref, + k_host_ref, + s_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + ck_tile::scales(scale_s)); + + if(use_bias) + { + ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + // clang-format off + if(i_perm) + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2] + key_offset); }); + else + bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2] + key_offset); }); + // clang-format on + + // broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q, + // real_seqlen_k] + ck_tile::reference_batched_elementwise( + s_host_ref, bias_host_ref, s_host_ref); + } + + if(mask.type == mask_enum::no_mask) + { + ck_tile::reference_batched_masking( + s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k}); + } + else if(mask.type == mask_enum::window_generic) + { + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, mask.right, real_seqlen_q, real_seqlen_k)); + } + else + { + // if left window size is negative, means causal + // else means generic (for current batch) + if(mask.left < 0) + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + else + ck_tile::reference_batched_masking( + s_host_ref, + ck_tile::make_generic_attention_mask_from_lr_window( + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + mask.type == mask_enum::mask_top_left)); + } + if(lse) + { + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, 
p_compute_element_func, lse_host_ref); + } + else + { + ck_tile::reference_batched_softmax( + s_host_ref, p_host_ref, p_compute_element_func); + } + + ck_tile::reference_batched_gemm( + p_host_ref, + v_host_ref, + o_host_ref, + ck_tile::identity{}, + ck_tile::identity{}, + oacc_element_func); + + ck_tile::HostTensor o_host_result({nhead, real_seqlen_q, hdim_v}); + // clang-format off + // permute + if(o_perm) o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[0], idx[1] + query_offset, idx[2]); }); + else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b, idx[1] + query_offset, idx[0], idx[2]); }); + // clang-format on + + auto [rtol, atol] = get_elimit(init_method); + bool cur_pass = ck_tile::check_err( + o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + pass &= cur_pass; + if(!cur_pass) + { + std::cerr << "OUT mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + + if(lse) + { + ck_tile::HostTensor lse_host_result({nhead, real_seqlen_q}); + lse_host_result.ForEach([&](auto& self, auto idx) { + self(idx) = lse_host(b, idx[0], idx[1] + query_offset); + }); + + bool lse_pass = ck_tile::check_err(lse_host_result, + lse_host_ref, + "LSE Error: Incorrect results!", + rtol, + atol, + /* allow_infinity_ref = */ true); + + pass &= lse_pass; + if(!cur_pass) + { + std::cerr << "LSE mismatch found at batch: " << wb << std::endl + << "\tseqlen_q: " << real_seqlen_q << std::endl + << "\tseqlen_k: " << real_seqlen_k << std::endl + << "\tseqstart_q: " << seqstart_q_host << std::endl + << "\tseqstart_k: " << seqstart_k_host << std::endl; + + break; + } + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "fp8") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp new file mode 100644 index 000000000..9a82ab6b7 --- /dev/null +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -0,0 +1,269 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "mask.hpp" +#include + +template +struct FmhaFwdTypeConfig; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::half_t; + using KDataType = ck_tile::half_t; + using VDataType = ck_tile::half_t; + using BiasDataType = ck_tile::half_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::half_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::half_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf16_t; + using KDataType = ck_tile::bf16_t; + using VDataType = ck_tile::bf16_t; + using BiasDataType = ck_tile::bf16_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::fp8_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::bf8_t; + using KDataType = ck_tile::bf8_t; + using VDataType = ck_tile::bf8_t; + using BiasDataType = ck_tile::bf8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf8_t; +}; + +struct FmhaMasks +{ + using NoMask = ck_tile::GenericAttentionMask; + using GenericMask = ck_tile::GenericAttentionMask; + using CausalMask = ck_tile::GenericAttentionMask; +}; + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct fmha_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; + void* lse_ptr; + void* o_ptr; + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + float scale_s; + float scale_p; + float scale_o; + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t 
stride_bias; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + ck_tile::index_t window_size_left; + ck_tile::index_t window_size_right; + ck_tile::index_t mask_type; +}; + +template +auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) +{ + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse, + args.nhead_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q / args.nhead_k, + args.scale_s, + args.scale_p, + args.scale_o, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_lse, + args.batch_stride_o, + args.window_size_left, + args.window_size_right, + args.mask_type); + } + }(); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + return ck_tile::make_tuple(kargs, grids); +} + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct fmha_fwd_traits_ +{ + static constexpr ck_tile::index_t HDim = HDim_; + using DataType = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = kIsGroupMode_; + static constexpr ck_tile::index_t kM0 = kM0_; + static constexpr ck_tile::index_t kN0 = kN0_; + static constexpr ck_tile::index_t kK0 = kK0_; + static constexpr ck_tile::index_t kN1 = kN1_; + static constexpr ck_tile::index_t kK1 = kK1_; + static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_; + static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_; + static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_; + using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kHasBias = kHasBias_; + static constexpr bool kStoreLse = kStoreLse_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr bool kPadS = kPadS_; + static constexpr bool kPadSK = kPadSK_; + static constexpr bool kPadD = kPadD_; + static constexpr bool kPadDv = kPadDv_; +}; + +template +float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args); + +// This is the public API, will be generated by script +struct 
fmha_fwd_traits +{ + int hdim_q; + int hdim_v; + std::string data_type; + bool is_group_mode; + bool is_v_rowmajor; + mask_enum mask_type; + bool has_bias; + bool has_lse; + bool do_fp8_static_quant; + // TODO: padding check is inside this api +}; +float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py new file mode 100644 index 000000000..56d699e5f --- /dev/null +++ b/example/ck_tile/01_fmha/generate.py @@ -0,0 +1,605 @@ +# SPDX-License-Identifier: MIT +# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +# generate kernel instances to speed up compilation + +import argparse +import itertools +from pathlib import Path +from typing import List, Optional, Tuple +from dataclasses import dataclass +import copy +import fnmatch + +DTYPE_MAP = { + "fp16": "ck_tile::fp16_t", + "bf16": "ck_tile::bf16_t", + "fp8" : "ck_tile::fp8_t" +} + +DTYPE_BITS = { + "fp32": 32, + "fp16": 16, + "bf16": 16, + "fp8" : 8, + "bf8" : 8 +} + +MASK_IMPL = { + "generic" : "ck_tile::GenericAttentionMask", + "simplified" : "ck_tile::SimplifiedGenericAttentionMask" +} + +MASK_SIMPLIFIED_MAP = { + "s_no" : "ck_tile::SimplifiedGenericAttentionMask", + "s_mask" : "ck_tile::SimplifiedGenericAttentionMask", +} + +MASK_MAP = { + "no" : "FmhaMasks::NoMask", + "causal" : "FmhaMasks::CausalMask", + "generic" : "FmhaMasks::GenericMask" +} + +MODE_MAP = { + "batch" : "false", + "group" : "true" +} + +LAYOUT_MAP = { + "row" : "true", + "col" : "false" +} + +PIPELINE_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineQRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync", +} + +PIPELINE_ENUM_MAP = { + "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS", + "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC", +} + +BOOL_MAP = { + "t" : "true", + "f" : "false" +} + +DIRECTIONS = ["fwd"] +GEN_DIR = "" # in Cmake, have to generate files in same folder + +FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved.\n +// auto generated by generate.py +#include "fmha_fwd.hpp" +""" + +FMHA_FWD_KERNEL_BODY=""" +using fmha_dtype_{F_idx} = {F_dtype}; + +using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>; +using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>; +using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>; + +using fmha_shape_{F_idx} = ck_tile::TileFmhaShape; + +using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad}, + {F_skpad}, + {F_dpad}, + {F_dvpad}, + {F_bias}, + {F_lse}, + {F_squant}, + {F_occupancy}>; +using fmha_mask_{F_idx} = {F_mask}; + +using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem< + typename FmhaFwdTypeConfig::QDataType, + typename FmhaFwdTypeConfig::KDataType, + typename FmhaFwdTypeConfig::VDataType, + typename FmhaFwdTypeConfig::SaccDataType, + typename FmhaFwdTypeConfig::SMPLComputeDataType, + typename FmhaFwdTypeConfig::BiasDataType, + typename FmhaFwdTypeConfig::LSEDataType, + typename FmhaFwdTypeConfig::PDataType, + typename FmhaFwdTypeConfig::OaccDataType, + typename FmhaFwdTypeConfig::ODataType, + fmha_shape_{F_idx}, + {F_mode}, + fmha_mask_{F_idx}, + fmha_trait_{F_idx}>; + +using fmha_pipeline_{F_idx} = {F_pipeline}< + fmha_pipeline_problem_{F_idx}>; + +using fmha_epilogue_{F_idx} = + ck_tile::Default2DEpilogue::OaccDataType, + typename FmhaFwdTypeConfig<{F_dtype}>::ODataType, + {F_spad}, {F_dvpad}>>; + +using fmha_kernel_{F_idx} = + ck_tile::FmhaFwdKernel, + fmha_pipeline_{F_idx}, + fmha_epilogue_{F_idx}>; + +using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, + {F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + +#include + +template<> +float fmha_fwd_(const ck_tile::stream_config& s, fmha_fwd_args a) +{{ + using k_ = fmha_kernel_{F_idx}; + if(s.log_level_ > 0) + std::cout << ", " << k_::GetName() << std::flush; + auto [kargs, grids] = fmha_fwd_create_kargs_and_grids(a); + constexpr dim3 blocks = k_::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + return ck_tile::launch_kernel(s, k_{{}}, grids, blocks, 0, kargs); +}} +""" + +FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp" +FMHA_FWD_API=""" +float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{ + float r = -1; +{F_dispatch} + return r; +}} +""" + +FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ +{F_hdim_case} + }} +""" +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ +{F_inner_dispatch} + }} +""" +MASK_CHECK_MAP = { + "no" : "t.mask_type == mask_enum::no_mask", + "causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right", + "generic" : "t.mask_type == mask_enum::window_generic", +} + +MASK_SIMPLIFIED_CHECK_MAP = { + "s_no" : "t.mask_type == mask_enum::no_mask", + "s_mask" : "t.mask_type != mask_enum::no_mask", +} + +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && + ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, 
{F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + return fmha_fwd_(s, a); + }} +""" + +def get_mask_map(mask : str): + if mask == "generic": + return MASK_MAP + elif mask == "simplified": + return MASK_SIMPLIFIED_MAP + else: + assert False + return None + +def get_mask_check_map(mask : str): + if mask == "generic": + return MASK_CHECK_MAP + elif mask == "simplified": + return MASK_SIMPLIFIED_CHECK_MAP + else: + assert False + return None + +@dataclass +class FmhaFwdApiTrait: + pipeline_tag : str + # sync with fmha_fwd_traits<>, to generate fallback calls + hdim : str + dtype : str # data type + mode : str # value from MODE_MAP + bm0 : int # tile size along q seqlen (block size) + bn0 : int # tile size along qk seqlen + bk0 : int # tile size along qk gemm unroll + bn1 : int # tile size along v head_dim + bk1 : int # tile size along kv gemm unroll + bk0blen : int + vlayout : str + mask : str + bias : str # true/false + lse : str # + squant : str # + spad : str + skpad : str + dpad : str + dvpad : str + + @property + def name(self) -> str: + return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\ + f'{self.vlayout}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}' + + @property + def scheck(self) -> str: + if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.spad == 't' : return 'true' # always support + else : return 'true' + elif self.pipeline_tag in ['qr']: + if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_q % {self.bm0} == 0' + else: assert False + + @property + def skcheck(self) -> str: + if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true + if self.pipeline_tag == 'qr_async': + if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k % {self.bn0} == 0' + elif self.pipeline_tag in ['qr', 'qr_fp8']: + if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.seqlen_k % {self.bn0} == 0' + else: assert False + + @property + def dcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dpad == 't': return f'a.hdim_q % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly) + else : return f'a.hdim_q % {self.bk0blen} == 0' + else: assert False + + @property + def dvcheck(self) -> str: + if self.pipeline_tag == 'qr_async': + vec = int((32 * 4) / DTYPE_BITS[self.dtype]) + if self.dvpad == 't': return f'a.hdim_v % {vec} == 0' + else : assert False + elif self.pipeline_tag in ['qr']: + if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) + else : return f'a.hdim_v % {self.bk0blen} == 0' + else: assert False + +@dataclass +class FmhaFwdPipeline: + tag : str + + F_vlayout : str # row/col + F_spad : str # true/false + F_skpad : str # + F_dpad : str # + F_dvpad : str # + F_bias : str # true/false + F_lse : str # + F_squant : str # + F_mask : str # value from MASK_MAP + + @property + def name(self) -> str: + def pad_name() -> str: + n = '' + if self.F_spad == 't': n += 's' + if self.F_skpad == 't' : n += 'sk' + if self.F_dpad == 't' : n += 'd' + if self.F_dvpad == 't' : n += 'dv' + if n != '' : n = 'p' + n + return n + pn = pad_name() + n = f'{self.tag}_v{self.F_vlayout[0]}' + if pn != '' : n += f'_{pn}' + if self.F_bias == 't' : n += '_bias' + if self.F_mask[0:2] == 's_': + if self.F_mask == 's_mask': n += f'_mask' + else: + if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}' + if self.F_lse == 't' : n += '_lse' + if self.F_squant == 't' : n += '_squant' + return n + +class FmhaFwdApiPool: + def __init__(self, mask_impl): + self.pool = dict() + self.mask_impl = mask_impl + + def register_traits(self, trait : FmhaFwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.pool.keys(): + self.pool[trait.dtype] = dict() + if trait.hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][trait.hdim] = list() + + self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], + F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], + F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, + F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, + F_hdim=hdim, F_dtype=DTYPE_MAP[dtype]) + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes) + +@dataclass +class FmhaFwdTileSize: + F_bm0 : int # tile size along q seqlen (block size) + F_bn0 : int # tile size along qk seqlen + F_bk0 : int # tile size along qk gemm unroll + F_bn1 : int # tile size along v head_dim + F_bk1 : int # tile size along kv gemm unroll + F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile) + F_rm : int # number of warps along q seqlen (block warps) + F_rn : int # number of warps along k seqlen(not used) + F_rk : int # number of warps along gemm-k(not used) + F_wm : int # warp size along m (warp size) + F_wn : int # warp size along n + F_wk : int # warp size along k + F_occupancy : int # 
occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy + @property + def name(self) -> str: + return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\ + f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\ + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}") + +@dataclass +class FmhaFwdKernel: + direction : str + F_idx : int # this is not a tunable, but a counter to differentiate symbol + F_hdim : int # hdim + F_dtype : str # data type + F_mode : str # value from MODE_MAP + F_tile : FmhaFwdTileSize + F_pipeline : FmhaFwdPipeline + mask_impl : str + + @property + def template(self) -> str: + kernel_body = str() + return FMHA_FWD_KERNEL_HEADER + \ + FMHA_FWD_KERNEL_BODY.format( + F_idx = self.F_idx, + F_hdim = self.F_hdim, + F_dtype = DTYPE_MAP[self.F_dtype], + F_bm0 = self.F_tile.F_bm0, + F_bn0 = self.F_tile.F_bn0, + F_bk0 = self.F_tile.F_bk0, + F_bn1 = self.F_tile.F_bn1, + F_bk1 = self.F_tile.F_bk1, + F_bk0blen = self.F_tile.F_bk0blen, + F_rm = self.F_tile.F_rm, + F_rn = self.F_tile.F_rn, + F_rk = self.F_tile.F_rk, + F_wm = self.F_tile.F_wm, + F_wn = self.F_tile.F_wn, + F_wk = self.F_tile.F_wk, + F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout], + F_spad = BOOL_MAP[self.F_pipeline.F_spad], + F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], + F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], + F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], + F_bias = BOOL_MAP[self.F_pipeline.F_bias], + F_lse = BOOL_MAP[self.F_pipeline.F_lse], + F_squant = BOOL_MAP[self.F_pipeline.F_squant], + F_occupancy = self.F_tile.F_occupancy, + F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag], + F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask], + F_mode = MODE_MAP[self.F_mode], + F_pipeline = PIPELINE_MAP[self.F_pipeline.tag]) + + @property + def name(self) -> str: + # TODO: we don't encode idx here + return f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" +\ + self.F_tile.name + '_' + self.F_pipeline.name + + @property + def filename(self) -> str: + return self.name + ".cpp" + + def api_trait(self) -> FmhaFwdApiTrait: + return FmhaFwdApiTrait( + pipeline_tag=self.F_pipeline.tag, + hdim=str(self.F_hdim), + dtype=self.F_dtype, + mode=self.F_mode, + bm0=self.F_tile.F_bm0, + bn0=self.F_tile.F_bn0, + bk0=self.F_tile.F_bk0, + bn1=self.F_tile.F_bn1, + bk1=self.F_tile.F_bk1, + bk0blen=self.F_tile.F_bk0blen, + vlayout=self.F_pipeline.F_vlayout, + mask=self.F_pipeline.F_mask, + bias=self.F_pipeline.F_bias, + lse=self.F_pipeline.F_lse, + squant=self.F_pipeline.F_squant, + spad=self.F_pipeline.F_spad, + skpad=self.F_pipeline.F_skpad, + dpad=self.F_pipeline.F_dpad, + dvpad=self.F_pipeline.F_dvpad) + +# TODO: design a more practical way to do it +# this is current supported tile size per hdim +def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[dict]: + if direction == 'fwd': + if dtype == 'fp16' or dtype == 'bf16': + return { + '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1), + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1), + } + elif dtype == 'fp8' or dtype == 'bf8': + return { + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1) + } + else: + return None + else: + return None + +def get_blobs(kernel_filter : Optional[str], 
receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad + # support this in future + def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + # this function will populate a list possible pipelines + # TODO: the order of List matters! the later in this list will be also be checked later + # TODO: currently for qr pipeline, let 't' padding to appear later!! + # TODO: how to design this more generic? + squant = 't' if dtype == 'fp8' else 'f' + pipelines = [] + if dtype in ['fp16', 'bf16']: + for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]): + if hdim == 256: + # if True: + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) + + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask)) + else: + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask)) + if receipt == 1: + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim + elif dtype in ['fp8', 'bf8']: + # no need lse kernels + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]): + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, mask)) + else: + assert False + return pipelines + + gen = list() + api_pool = FmhaFwdApiPool(mask_impl) + + for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()): + d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype) + if d == None: + continue + #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): + for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): + tile = d[hdim_str] + hdim = int(hdim_str) + for pipeline in get_pipelines(dtype, hdim): + if mode == "group": + if pipeline.F_spad != 't' or pipeline.F_skpad != 't': + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + k = FmhaFwdKernel(direction=direction, + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + api_pool.register_traits(k.api_trait()) + gen.append(k) + + return (api_pool, gen) + +def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: + (autogen_dir / kernel.filename).write_text(kernel.template) + +def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: + (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) + +def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: + if output_dir is None: + output_dir = Path(__file__).parent + else: + output_dir 
= Path(output_dir) / GEN_DIR + + output_dir.mkdir(parents=True, exist_ok=True) + api_pool, kernels = get_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + write_single_kernel(kernel, output_dir) + write_api(api_pool, output_dir) + +# list all the files that will be generated +def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: + assert output_file is not None + file_path = Path(output_file) + with file_path.open('a') as f: + _, kernels = get_blobs(kernel_filter, receipt, mask_impl) + for kernel in kernels: + f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") + f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n") + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate", + description="gen api for CK fmha kernel", + ) + parser.add_argument( + "-o", + "--output_dir", + required=False, + help="write all the blobs into a directory" + ) + parser.add_argument( + "-l", + "--list_blobs", + required=False, + help="list all the kernels to a file" + ) + # TODO: if using filter, must apply same value to output_dir and list_blobs + parser.add_argument( + "-f", + "--filter", + required=False, + help="filter out kernels that need to generate, using fnmatch module" + ) + + parser.add_argument( + "-m", + "--mask", + default="simplified", + required=False, + help="mask implementation, simplified/generic" + ) + + parser.add_argument( + "-r", + "--receipt", + default=0, + required=False, + help="codegen receipt. 0: generate only 8xhdim coverage\n" + \ + " 1: generate more instance to cover all hdim" + ) + + args = parser.parse_args() + if args.list_blobs is not None: + list_blobs(args.list_blobs, args.filter, args.receipt, mask_impl=args.mask) + else: + write_blobs(args.output_dir, args.filter, args.receipt, mask_impl=args.mask) diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp new file mode 100644 index 000000000..56fc8b8b1 --- /dev/null +++ b/example/ck_tile/01_fmha/mask.hpp @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
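As a quick orientation for `mask.hpp` (whose listing begins here), the sketch below exercises the string forms that `mask_info::decode()` further down understands. This is an illustration only, not part of the patched sources: the function name `mask_decode_examples` and the seqlen values are made up, while the literal mask strings mirror the `-mask=` values used in `script/smoke_test.sh` later in this patch.

```cpp
// Illustration only (not part of mask.hpp): feed mask_info::decode() the string
// forms its parser accepts. seqlen_q/seqlen_k are arbitrary example values.
#include "mask.hpp"

inline void mask_decode_examples()
{
    const ck_tile::index_t seqlen_q = 128;
    const ck_tile::index_t seqlen_k = 256;

    auto causal_tl = mask_info::decode("t", seqlen_q, seqlen_k);        // top-left causal
    auto causal_br = mask_info::decode("b", seqlen_q, seqlen_k);        // bottom-right causal
    auto swa_tl    = mask_info::decode("t:128,30", seqlen_q, seqlen_k); // left/right window, top-left
    auto swa_br    = mask_info::decode("b:4,35", seqlen_q, seqlen_k);   // left/right window, bottom-right
    auto generic   = mask_info::decode("g:64,64", seqlen_q, seqlen_k);  // generic y:x window
    (void)causal_tl; (void)causal_br; (void)swa_tl; (void)swa_br; (void)generic;
}
```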
+ +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" + +// keep this in sync with ck_tile::GenericAttentionMaskEnum +enum class mask_enum +{ + no_mask = 0, + mask_top_left, + mask_bottom_right, + window_generic, +}; + +struct mask_info +{ + mask_enum type; + ck_tile::index_t y, x; + ck_tile::index_t left, right; // FA style SWA left/right + + void serialize(std::ostream& os) const + { + if(type == mask_enum::no_mask) + os << "n"; + else if(type == mask_enum::mask_top_left) + os << "t(" << left << ":" << right << ")"; + else if(type == mask_enum::mask_bottom_right) + os << "b(" << left << ":" << right << ")"; + else + { + os << "g(" << y << ":" << x << ")"; + } + } + static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k) + { + ck_tile::index_t x_total = seqlen_k; + ck_tile::index_t y_total = seqlen_q; + mask_info tmp; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string t = str.substr(0, found_0); + std::string v = str.substr(found_0 + 1); + if(t == "xt" || t == "xb") + { + // xformer style sliding window attn from top-left + ck_tile::index_t window_size = atoi(v.c_str()); + ck_tile::index_t left_size = -1; + ck_tile::index_t right_size = 0; + if(window_size > 0) + { + left_size = window_size / 2; + right_size = window_size - 1 - left_size; + } + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + left_size, right_size, y_total, x_total, t == "xt"); + + tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right; + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = left_size; + tmp.right = right_size; + } + else + { + auto found_1 = v.find(","); + if(found_1 == std::string::npos) + { + printf("not supported value %s, %s\n", v.c_str(), str.c_str()); + assert(0); + } + tmp.type = mask_enum::window_generic; + ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str()); + ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str()); + // TODO: some validation + if(t == "t") + { + tmp.type = mask_enum::mask_top_left; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, true); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "b") + { + tmp.type = mask_enum::mask_bottom_right; + auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window( + v0, v1, y_total, x_total, false); + tmp.y = r.at(ck_tile::number<0>{}); + tmp.x = r.at(ck_tile::number<1>{}); + tmp.left = v0; + tmp.right = v1; + } + else if(t == "g") + { + tmp.y = v0; + tmp.x = v1; + tmp.left = v0; // TODO: don't use this? 
+ tmp.right = v1; + } + else + { + printf("not supported type %s, %s\n", t.c_str(), str.c_str()); + assert(0); + } + } + } + else + { + auto set_causal_top_left = [&]() { + tmp.type = mask_enum::mask_top_left; + tmp.y = seqlen_q; + tmp.x = 1; + tmp.left = -1; + tmp.right = 0; + }; + auto set_causal_bottom_right = [&]() { + tmp.type = mask_enum::mask_bottom_right; + tmp.y = seqlen_q; + tmp.x = seqlen_k - seqlen_q + 1; + tmp.left = -1; + tmp.right = 0; + }; + if(str == "t") + set_causal_top_left(); + else if(str == "b") + set_causal_bottom_right(); + else + { + tmp.type = static_cast(atoi(str.c_str())); + if(tmp.type == mask_enum::mask_top_left) + { + set_causal_top_left(); + } + else if(tmp.type == mask_enum::mask_bottom_right) + { + set_causal_bottom_right(); + } + } + } + return tmp; + } + + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi); +}; + +inline std::ostream& operator<<(std::ostream& os, const mask_info& mi) +{ + mi.serialize(os); + return os; +} diff --git a/example/ck_tile/01_fmha/misc/gamc.png b/example/ck_tile/01_fmha/misc/gamc.png new file mode 100644 index 0000000000000000000000000000000000000000..2c96951f30f99ff0345706c4f3031ab2a623ce1a GIT binary patch literal 30073 zcmeFZ2UOI_w=QY|6&g@UO;$vb0u2ZV2q;L--6TPB5+o=R8w3?;6c8jQ0hQQ9$r4%y zn;;-4&}3;uk~B#|gGhSSfHTfH^Z(y-*E;XMyVg6i)~s1W*RRsvwX5p;zFkjrw3MkR zm?;h%I6$SUat(gqz~Srz2M!4xI|Tl6lpw`@;DE;!)oWMvya#7fv~P+I-P?{}aWh#u zbqu4KOqcZeq0M8_Wbp~S5~_y3tU|(t!i=nzuWR1k{-t70JnBMOWE{i58yi}M`sBXB zBNPw%9L(U=oni%jHgR9nPd}647PuY#tFjrs zD)@qX2C~m_7US8WxArRkrU)Tt?4_ezPI7vFO#RDh1c6|{7Ubwg>2Dyx*y37#VsesK z$l+;jH-d25?Bk2{)EENZW^QQuikQ$?BOT&p-8LFIfbhZ_2sxy0Xx=3D6F%!TR6=c(YwEcNzo zes6J;@!t-;_lErmx-x-6fta(dB*dl4lY3-H&}(AIY2?!1hW{-KuXru_!%;-G<6r{K zhGx$n9MtYO{XP?mT&XF-7lXMvy*{t`sY;1thPQA6eruBu>lm)*%5bk*y&$rVJ+q?_ z*Qj=yf@ew53?l1!`HwoPrLDv#kgYf8gS3bBwLzwQTaxQJiVh}FNG>(|BtUvwv zXV6q=qu||*>#rwh#0z`B6qRPOKR!f^%z4~*;Svpb znhUUxF^-Wq>oo6Msm0HWYmC#_A)ys9V>zRr+S7T2rqGp4E7)=U6N_6EeR^GzD-@~f z4Hd3&PuXt~C-KMfHj>41Wn1As@Q*eFZxfp%sDAgin^|UZ@JB^1+RB89QmLwAhYT)V zuT}`8mH0!T@VbzoOoX`Ltfkn`{6x*-TG_XCo@xp!bCYWC+Ia3@ClpV;j+bUhkI@(S zBJ|0a`tDELA8FUJWWsRkcepxGzGo9gKN*gx3yu4Xl|<&k`Hz1*|M-J&g0)Aw&@F8T zBA2>rL&c!t&sx0ZJ?MO*Qruj!$D6(vJ2rcZgRB+V+KcXE%USMfQkTmk#@=Owu)eKU zZ9gCQl+yQ@a+=t{`5Vb;4x6Lc_n#cp2*Wp2D{w65U9Vqh!VMYPdlBI1mRLtYyezpj zCgrZ{1D0i~7bVG838W%T@Q^mzg-L`Xf?!(udgtmYjz#SOC6|4BgI-L*Y`!%H zEXo4ore4L{H)@TteB_4KpRU@K|13S^_mdNE%P}*YfDC-f-ldNhU}JuYGswGuc^S8r zCPbt@<;&f+TCfrT*=zIx=Yv-PG{KGI#?5y|sDrPv8;6Ola^T*=*ZK%6`fiiSpKXl# z6;9myId7Vi^@NW8x|^>4o7bnw%W;7L(xu;q&{_k<`ax}!MFtqX=J?1L&tK;%2BR2` z{SgA|Tbr(0>7D3yGb2C*Gtpup=8%=FAhIMVm^oGGI ztc^Exr19garuB)Xk_|!zncQUo^3`6US*w$&wHc>gUW-nwoBSTr{{4!@cgk#^&taGz ze(86n2s2;*vaZ2~^ql{*Wj%M4f*l(Q{rZScMbXurmfiSFMiE1#8`b6_kG1%uOTeWZxP7m-cIPmR zAjhmMqxLc&&|sA#E*UV5Eh z^zcs#-jk+MtT2M;)oI(sO!3} zy{A@z_+ZnsNrlc#8;LNv-XVAdzO=$*F4*u<0yErw4K^TBB-|eSO1;GeDu-F+kzl0S zc@qOgPN{onb7Pig0*AgM$J&~#gwU1lt!vx!BLO?zJ3c?k=k)wOmJK<+VO{VFo|gZv zJgb7qLgb2Tz|J>m!nQPNeQj5b+WC5z?J7R6cMs^VWuS566YGl;gLls*${5^LoA8Ck ziY=?n6(=G1#Xnq_5oe4jf){KDO3++pFvUFY_CL0ja~e96@?5rgO|AHZ20v}ZgpV@E9_gr{dsIUe2Bgr#(|6q%bK(BA;gHIVVgs z&b2pl?T#=L^<)mtz@yPV5SpDAYufJShHp_1$gkY+!{1<=q0+sTf(%$1Z7!1i@olbh zggW3ysrja3`;dg(PMw4+zSet@)A>OUO~B7ZA?p$U<%j@#&%uoDIK`+Gr?*AewVU)(QD>AGXXoR9NX)Z=0D_rU_aIp(PFOVn?H;0{TK~i zs9n|?Je0@bKeN$Rnpotw*!V%v+tI1{PNoxq5oHNBec~|?WZCA&4^?^*Vq;>JFqk`DdJc8fmL)a>zA&=66|mXZ=U>ym zs*aBs$|#98)nGzJ)^aq%OyD2HAQcU;a(!W5q80R(9={r9WOKfIZF}j(f1$Jrh72o9 z?ED;daaMoAB^g5{p z>itiXv6w~g{LCbpk@Ze|WnJpv<>otf0UqalmM5R~F0D3(I_l-#QGwV(l^FX)(kM_T 
zb2ta$pjj3C>rPKBbq16gzZB#B&?8yi^(q?ZJ*R16riUOLEpy)M=|Y92A^>Nkw(Es`D|NN8;utYQ)=dOQG;h4L(F@tWu z%WA(`IjiPJpZrx(jK2qj4awh)Lj`QAdIVLhl<8@Wspc;-uj@6&zs0pZDgU+RTI|&! zfg!y6)^)i*?tS$!UpcIWKg#)urDL8vq1+D9dtK2Ing>=O%V|lo6IN;h8YTNt<{7eJ zAjZRB&utdSRl*q`CBP}8^Y>jY!|ay0ZHhh}s@OOPS936GjOdeKhWEY5utQ4;TnF zI0Hq0LL)5arYCCWz zw&|&9V@45~4bd6H6#-|tnkmmKgiV84CM_zZ3^Oo&3}jfNEO=wlz$PNI?Dt|KpZB`k z_!yT5gYYNOL9Cf*bBz zavlDyfN0o@VzXo=T;iw>f1@n6-)7IfKqoj6icXd}{o>gAEHiG&a+hTU)mdQ=jPlV*0GU#pQ&emXES z{h}L@Ey1WvURlRQ4`0}1L!HccXj#EaYMg0)?4#D_`wHjizO z_b6gNuf}J{6R7YOP*r{?cK~TP^}s#R772rSpkumch~%!YFfMvXbRDT*7Erx1J($KE zSW=le&yax|SccxAmg6}V3C zidht>tBUGmSlkWUfDS z>i5Sr9ir7__Z(*YzGk~J8TQ4g^=G*ASA%eQWG%^E+`EU8k`8`J3&b+$K+ZK;E~1*w z8KUFWrK~d5-BU?^sTy@Z&QA{*Ad*HEa|+C>@v~Y!XK)??$C8=+sKer79AJ!U;|lPi zBcMVCE?2c=P^#CdRVwel+Ux~xaK8js(_te?d&A+-^ro++8#Rlw3ILy5M8968)J2z@{hHIKA$EimYAH)qB4jy8OMR zUHALsJft~4aR0mLVp3=W2Hkg&%*Vy}rmx2X3Oq+@q`vsv^mlw@Esh-owQCqeiu8g; zR1tqR@+}+;GhOw^&}64uVVgTnQjR@05wYrfNlCqN?~{Q>BcZ*jU%bkEXh|r9Hxvzg zj?UgH^mVqd3~yPw<}jH)DSX$e7mLRwAAwju4v8@@wcK4fdrL%%KPba&Ve?z3?{Jcd zwSWDUa$^Bv&LwT*(4| zR(P$XHCe*7EUrj|sCHWg5+#wTxp(yXEhxpx>!~r#i189x3ar|VipI5+6q7>w-|0l6 z|B72xqO1_dx=zrVo@35mf8LLUCSLS#A%3ylpedsNXGD`H6x2|l8Dp7uL*H&$#Vk_I zXNbL~LAX#gs=vza#p7qlnL@A|WTHqhcCYMSFH|rdhG⪼0mN|Vz2CejK)`kn5K(< z`5$`>cO3=+EeEIbS3KYC8U?y?{{OyVRAic+rnT}CzQPTId^&_WX#5B43gseBnmxkH zG|G+DzVF<9Jr!U;sFKKU1%$DWp)&j@*vhy1I%i10xU{SrE!)OVo;CO z@ZdR*o@(`+2C@p?p0GbFOE-!}dLT5!ZH#VRbQ!;ZeWq3M_*LagEA@EmigGe1=0q)Z zj7h9xQ`VzBUZzC%f-0t4cWTr7N55lvN@;hY+-os5DrBFR^~41^H-#(N+10~a_;uoJ za_LZVXt}k;;)gGV%iAuq@S@KiDEE}W`Suo@a>z>=1Fw)e+YOqd+NsKgF0PZ**UL3O zn6^L38ZfIi_pw84TZyiHzL6jC@}MEH!#>p9G(~nqZo%uUb^C*php);JwY1WFM_0MF z%-`b$_aI6&2Re1Q>uqntt>8Rb*J~dr#d5c94>@Zg+b4P~7vRN?DY_r+)06hxf>Pjj)A;#-DiuEVek6nbYfpKyS>njiygMqZ?`vzhzd6 z{xq|S8to%n(~zl}UaR%-60aoOSjZjnn->#mE`z3}qDCIJb(n+^k+&SmhwQVndD_#J&Z@@a zZc&^YGjg?HX)nefDUC$#kAMwvK^?Qz((Y)*-PWE|;tQ zK_OIxYt0D5H%e9h8Rby%#M$6pDSZMmTBcT(=*ZHQ@i%81hTZnQ?58)^__FSQ?j=|#qG#G*L)5#=@MOgM)3--Y z)R^wER57e{;7r#&N_gjkYF@hP1OOVixEr&pREjwB_MCSzOX@%c{ofYujCI^u14fsN zxy&aShOEW1V!JSucIV$4w_l1+|IRkywc8FI&Gpo47FSGO=y_-?rFjk4N$K~s@n|IU zG4RG_;=df9*809}F~?Lvs_H##p7%rdX1n*B>3T5|(%l6~|4CEHqbW9>!a>8RDt@NE z$tN$;n}}b^Fr(#7zdB`>g7@ZmKmY0TmcN9p_c`y=4CqRH%q(ir#d30_*lA0!F@uCN z--IeHJdn1#$Ppr_h=r*?Vt*-!24i_)BcXAz7KJZp=v%jKwYYpq{`s`b!qKU*#j`Bv zzGG^)S4D@uI2PNUo%-Y`Ir-|Yo=64>{gy&4#&LLbC`7O`^f2(!>y)rGam+ZQo7V$h zo1e6DROuYX+O*X+ky<-fW4XC9nD+U`#mzCe)wdL(9+||Ijv^Wt?95XkmQYLvzKd^U zqb6WyOM7Q(w66lODAVndH|ka^x0>rRTV`wE4qrt}2RzM`kX`gtaF5WRgH;dX}bu)v3)=gPr>1qwJxDW*rW4kNovc zgpuOlm~W=hF>ry>TA%b9E=`oXi{I`p+Mb+}+nD~$zsOV*Ml@DK55krkl>-Ml-i`iD zzZ1j8)P&7)5Z>(ofvR-XP=$O2G!x6L)SgYO#xEXz-JIaQT1fM*iaViWI=uIS{iP2@ zR}=%Rq(g42V(2BCnm0DmHkRdn_%V!drSJ?(7Br789aWycC-pAoODfy`K-DBoW>`sB zX=?{0T?L~A&lUf$^mTR8dT)GiY=b9XC~el{Y-4v%YoVQ_pskGyV=)~iRl4XW6HO9m{0Y__ zQ;aua^=y8!W!})PeCr`@rkyeTlA!E!#0+kuXopxAutBr486@O)W^>dY{}Y!~DdfM^ zC%gF_>xK8>AtC8?#I1>e4~B1lSlPu3)ormDJ3rBQGU$y5)WTD%bj!~}4rUzzi4C=Q zD?V}NCYjIA&8cOYohrgU%yGtn-%C2a734=dzv{o4@SdG}GjKqqY_ZupaQNAp<)s<& zHN%yJ^L^=#DagE+WjjC9?6NgQi~PoD0&cBV*8XRFnHD0dueNM|MD_gjluBj5ZFKF_ zE1M%{0^TnAPyc7UJByhV84Ov{&M+Ilz3Uyo?2F_g@iDVp)8^beMprE*PnUvi5kTEA zNX!x`6p5i)9vEliF>1DS2+a66xCSNDTmi`mwH90eOx$%E>!$#( zC@0|+yb$YhlHr~q1wIaNBh7GS9L3t%>~*h64B6@e6E@qf}2U>&rIQ(tJrIm*(ywYpG7?PYgF8^A@es@ zu~R$#0XugO?KN9)MhPH?zmV=<^(ke(!kyo%$%xgeSE0lbL``%-$W?S@*krm@@E&~8 zIvf!V5MGu2h5#~Q_2A$CQY=6`lz5^h^q|Oty-kXA$$y-UU-Yh>iHaG!R3msWL;j(H z)gQ2^-wQA)eKIQ2JWKd+n>QA4;!Y)iM+NTnCpzPlJZgauBK75e6Kw_{JfKqq{;^UO zj2j6H28?@{CMn82hLh0nee}ez5JWV+kb9`&mN;Mo-M^5U0aXGUmDJkCaB_OnEYbbb 
zG`c%E^laeJE;PGMjXefXKa<~~*_vG#kvciP1Ir^(sIHWP8M_m(M0jQCvY`%e&i`JJ z8tw|PGW-`RR3yOBml%J8d@&)bMFR|iiyxNz-wCUNLhI6yvnpg*C*9u=@Ol10FF_V6 z>IfkYdm-bw3Q<J=(1g)$1c>$keN&R5)g*@`<~e@1xUQTT6C^3&6yED7kS z8W_a-QLFeel2T#*L5|nRJwrpqNv01oeMRMQ5)S#5M=YQuN<32>0M2J$M?1q?4nnNK ziDWlI!HWbtnX#m^3Saqpjs^5Vic_>2wFQ(!j*}R=XGo7Hy-+0*ZH^{E;m2lelJ(pU zV4>jDvKzy2dFo|NFcQJ1<}3>|0~S0(8BvF!O{ZgDAxOY|K}S>oL)FU?De0`Y07E`z z5Sf4k%n4M<7!-VH>14=vk%*-Efp}6`*QRqlg$>aSfL5El$Lb`=8gjM(^tT=+ro#s# zR!2xORs5te)L#2&T0nb_710JtX5<~)EeRNvJ;EU9H*h<9IP~mCmQ?dWnzZNm?8{$W zt!+I#!~Pdg_NWxL=EC<&%Wdttguns`0|WbJqhxR#Y~hz?nj!Io=zYeeT|WdmqmS21 zSj9S5qH!604Z?uHvknZ2A%Wd$Nx#5uGexg_FDYitcO&8ve%Q+>Ij2!5-1Pv&`gN;d z6KO60MYB%{#lBnUYwI}(!l>|ATEN*yTS@IlQCG>ZDl=f#VPO;dyN(K`TA*Za(p#N= zO*)9WhNA;cB%>KcB9{;Rzhl&0y?`+}YqPGt?VFOJnMy^&iS}d($w4dSh=~Izw9>gW z;3{O!S2G{JgK5aTSR zHy%hNA>R8EqXxu+_hr0?YU`cnrtt^hXF${fA?Wu4kh50#jQi?O01n5%`7+G*wXfi> zs|Xd$zH}6*hi9{&?E{J=DRIB{8KV3?O#b3Zahg0KXkYbz<$;wP$OPC|96vbJ=nu$? zKide|qJsOH6x3{q{7h8$`Uf~8n=EB*r!9$tv07>o9QH0rGyF}s`8@FwL@djM$ zJnL<`_6tUea$LI|4WaZ(gE{>Pm8X;_oE#y6duAiT&OF`$x0$8*6T1OU#$+e;Z6$Tm zwFx_#Zd?<6l7uhLS=O%P{sY9(qtOButJOb;ns`KA48{9q9a_5R_%ycdZ2R=SIn$K)wn9!6&G`Pc^YW zr0O)7!j39G7R5|$eet{OZ*vw5!HVPDCDYWoGdL^56HM;9omZvSLb?`poFSs9Ul?97 zJ`U;UE_bV(!|Z<-;R7CSqx*KvYiD~SZTlm($T|D#g3X|_*bTGOZp|JFH!sR>L?}zt zuSjUl#ts0CSnpQCfZLQdnDyu^cy#P@8OT(#WRSDo49p-J2?74{kHhsgm$KjxcapDiN6*c?O_LTYHpl zhS^yFZrb_W@m;=g`@bYmRRsaUrcqmV^GDB74$Ri-&?rs7YN&aE&&K4`&e~*qGNX`L zhQRBequ?%NJ3HjMy}0n9qts)teKnr$xXL!l6IwgAR`M^w1<+bZHf~u5UE~Ug^hB2P;%k;&(c&9 z;S#C1t4-}FwU>@}-5hmo-;OghqJBRMq{V;-fE>=e+#3FOU@z#9BE!GYrk=D(==}2C z%-Iih^X7WT=$6CmbR;QA+vwWa_w0zQH&amjUCbKkv=^|V=z%oVc;=|S)tKNpatPus z(4_9leMCP1Z3fsD4t%TM-xzMkKm1eh17xd!#%@2DXdR+s1K*^b;n?}nY|JlOyD?u= z&0jr6v8y-R#;oDn`WTA&t(2{7n7PeONLMQK3jq;*62zzXi6InReRR0s+mj1BV?u9N zhaB3}@r~f_%Mc8`OLbP#%mO;haNT z=;Fh2xX~q~ss>8B>ZcEJ-U?UUJ14)j0qvsDW|{ps79Y4N`so8JiN?>QrmsFGHQCIy zcVIU!0b3YWujo}GQaP{#BZaB@&i?f>^9oQIwPia$Vy~MOl-O|u!PKt<=xGQ!o&6bb zt&Nw^YyYm3Kuv*cF{r{;IVRq@_lO*sJ-Wb&~ zOLD93-3_9^#6q)>^jl2S{`1b2J-4A9wUIB9pCehiJHOa2e|=bN#+&@_zKp%RjKpYc zK-tQkc6`%R#gK`UhYZu*^%5p2`IIUZds(ANvfh|MzZO97Ll79Xx5aqrY~m$9Cm|Fv zyLf<_4=r(d$g8dCN^!E)yr2>ZocUJqKFKh-eX`n!&%bgQ5^ITwd2}5nEy6!M=4pxpMr=_a94lR$ukFNMlDY(`v^$I_dGD zw{#${5^6#tNw5z%fKMiHk;JgLNp&D)wMro9^LvZ4EF)juM0^+~x!fa?tKHRzA_}*Eb!fpI=V0gU7IgIiM2+m<&Vz$BCcR^H@l(pz7shkjbM0B# z&>UHSYk<`i?%Qm%3xz;_c|!sJPohFGPxLM6vK8^W1z9)1H%B&&tRK8XrGp z^jt}J2lFJ2bOfmJ1zo~pwtZ&9ym_~EW|p28V;0X~mC z93@kH+yU}(lu}`_4*DjKBvoKI6@uF){^iMlC&Sm8TjVSHu4e3GzgC;ZYn6E!gC z`*B-p^XV5aKY7>SBxa|E`1$HuF}?sysp8a1!)l?sn+;`$cUZKP zSBe z-d^4!PQX5MS4wo?q5}k?P36W%%)}R4t5+rvItjgNR`&>| zvWEuPD%^$M^6+PA=RWQ>jB`IKCp3H#-poWpOKz;+giFT%p(6m|+)GpgF7MP3`}H00 zHyl3JY)PWDrlXgWsxON}MnRQzz)1xoJ6U?1yJ*$-cu2CpQALLG(N|N}NfMoxZhSIw^$S;UBWM<$ zsjX-6r`lXWpeo0i6lT|{(Wo$f<(gB8v#>oVLI&U<8p>(ASlpccl+p&2Azzg3+8dZTGX)k*p7*bQ`WX#(y=fex)-)n@kog%ybus`zw$< zLtd@7U_2{s*57GOf$g{5-6{%bgoE~31V%c@zyA0iL;Mk*MsW7EF*ouyam$&tiHA#I zV}@JwSD~oRb0K13Z2V9DfHhhL^H*DA4mjPn<>{5w!{1vQo~4|Up18O5(?iU;@{3{Jtp!Z5Rsj6(qSMk#{fFb)X{AAbbcUm(=2F z+!Xr1!c5Bf?b|b7HsE2e#6b`>6JS6UPbOR=w0CYwu_$GM7+T@f+67-gO{F*ozE8Pz zx7d*VVDY&uk~FQDTOwljbQ@q4L~v+D~)HP2HB(mQjJwV`T{i$EXZEYkN_t zx!V$1^SlyN^2zR+#Wd^VUFtww4lWbXS9uIC;}Qh3YH< z>>wuMwpRg!%e+#!(4;15Ur$^2^}6HRM`OWp-C7;xhoml zqBjd-x2{mi%tKCzKF!PPp4tFcM|H2)ME`-XKo_0+O6B#mpN$Jvtp?gLFFDhKP=2O_ z-+_mSFtm-p`;UspNv6n$R;@Rxcla#z+Tp0!bho!nGbYHqb}51OIs)s3iiJr zV#Bw?)G!weqzqlR1ce?SjSz7mUD37sn(MpW6&~feELZi_hEYHMc_AX>##MTMPnXTG3FF+_~#{#&WBqP~W0^LZff3sf7M&4V;~v^#fai zd$Q2SsfgR*<+gQagj1ygreR30Xfxow6g7K 
zMZ4f+HpkAA$w=Cxx}SN`LID$@i#3I>|K$i%M!$qFvmHGt_B`{I;S%Dj*|&BzD>J34 zbqc*p>>HtW2vwsID*;;9Pezr4E2=dun3D!Knx@}}4EKuBO2fqYIIHnJ2a?2HEuEdw z>bZ@HiY8Q8Rc}FpzJ#>to9l@SuO}ulNPO7Gf>s%x0{e!umY!FZm)>gY{0eg zL()gAJrLk{16G6v!Q&{o(C`_RW~QTecw;i-&OQq5Po&dByfW>@CF|OH+*7@l;%LZo zH-GM_BJHVYVu?rV#`jUrohnA=j~&mEh)IZjj=;T^%yB;>zf3vh6zte&y4-0q9ONjy zOr_T4JcC(M+e+iYEsyt1?Xqub70l>^E~8_bAYs#viY^f;n%mG`2Q^YOUu^ z|Mw8a3C+#p6#IKYdZgZG6pp2?Tt05FS3olUwsAyf3-5XBsrrnFiR_^r_dnz+wWR;n zVje8aKcYu%Z_V66dr3BWh^Ys0_B%yR=na&yGKI&?FXb&GY2mKou{{n4wjpL>i5hy{ zq4Cu{9rmU820MoQnWrkDap!ej_I6|w@jv9AExX1%OF$l;zIlL{OXna=9 zu)u<1VNxLWttJO<(d~z~%14Vlj&~+*MYFOBjA2_Qx!@FPRn}$x8fB-idObh zRDi^1yU_^v;n=``2)28aPsn}TjP{N1W)(jnE6g5$*4JGdaRKG+h!jcXZS;?1|JbHjk4Nd1M9OWBk4JYGg;E3KEPRrZNC|#%$J`3Q3 z#gBmc#Vz8FM(&lyuh#IcENshu%Hl707niiQRVe}#T&t!lo+t^dMm zQ5e0)F3h%BLxU{Goz>t%F$6v!6!fq2y*n=JYXff0xas!B&I+Kd;acnWbgqXb>RpiK zcF46z(BS8GYIOujHQe;H;b_OXxtz; zcaW)uIG{ehlVfcV#Xk}!0W$CK`s^t}!M^!^Itc|X)6~Z7L^eG2g%(o1x ze5t22>IToHKub)sK%#R&IR{EmXRg&?R=yG$HX}Eof=>+aEjq248rf%iZZAB&#O9II z;tM6YDI`7nUsoKX!9)TT=ZRcBj%%*2241ke_=f;q6IB6b?S)m-ZW}f_!q$jMI4Oyd zY>nx_IqV%)+de_=;8wtymI&RjjX3@}(*hdX2D5bqwsd3U_S_wUJ0)!mm{#FNSQt@$ zKt_>qHN|L3$7gNWt5oRuLc>quOrPW+e0ACvkf8ccEN#Nq#t)F3_$X8<(HaUoVHn{K zuDg*)kNcJ}lFb?K^Z^^n3rRl=N&g@WK{K|_^Hp0g0?DG(AF$TvBuAIL57tTNL6q(! z6osR-U86I79jcpQ#%2Mv?I5`+>mOhZd;^u-g#eJdgr#*?ed`nbA*D2LzERiutnEQh z;aqNUa?<*Q?RmE!{nYGXsmEun z^x8qN958E)v4)@FhuZl9lGL)yBE%9HCO5kw?eW-K4#!!W{>B_6NOD2k58-Rh-~=vp z5D=2soAtXj$EYUmrwRc!HNB&5x-_kv127r-#*^uaLcNojmQ3HGbc`{pbMsSXE$1;a^k>A%lk>_9dU zt7BL*>!&%7lZj6~0}z>waMh*&|Ll zADz;~s5YAhG9M-jLDMyp7?54^H4_f3{o>e*Uqo$8)Gobx1f((T=hw9Cvag4Df6=^| zHoN5P!9Ph^xCn{64LR^9XUmsZBEcT9Wmq~ke{{L;z2_&U>^hT7O4fHy%@fC6wz`f#AiN_35(gxOwQ9=Q8)`m<6@6fZL!mT9&#Gk#{=;>4St`80G*8<~Vrx?3>H#O00t-C~;p z9ho-EiAOCogKpBF<8&Zdz^~pMYcjNxW-7nrc}-~=r1AcHj2+Bqx+v1A*Br`oOIo5@tLyd_?O2a%j_xP zpCfO}QRKzolH;DAO+GjAQ(c~mYU0!SY>;<`X-`N#4__3|BO=62dU&rd?b&_{3)@?3 za$9M^=0jSxviD!0+Rle`PhZqrL|uPQJGheP!jK6{`g zPC1GpTxat>8An^+6xBR$LgW9!WVXBQk+RAl)!iUUoUiB!8C}el^4F@V}s%U2>4s#sfV+4yQTFQvTtS z;GXs|InNq%CRY{hbtf=L@j}W}OYjiwgxM|MmH9WH{Hmfvwj?=2b$oqArmGabsb(MX ze@*S5KHI@y@RD9JlVSZLvx2E`>?^V3J)&vKtaQR^Klt-*SjP{Y3aO`qyC%kv>~Ri7Gh|0r|TMtSTv$cjOCi@5HEiytlmr4>W)@Exli#yiXFlvfEU5x1pYDyGW? zit-b=!#;H!tzxWWzFYTs!G}ZesPa@@Re}d@J+tnO#M!U9=Z0&X-i9q145f{nJG0p+ zA{cC)LlK&JQ}Y$^hcLcQ+Biecw2I9|otGRabZp_<#JcS@*PYR{y6-_kkK{H#n52Or zG#2d)26OoT=#XC6W#B8{*0Bn$0~PWIYL#D)Gr3ne*a%kJ6N2YCOOr{jK>aO;B^DyS zQ|~msDjeEU;3!MY+)-V>MK3eiO5`ntyJj96Df@8j{buDO&Lnw)SX$TiOu$c&-6JN= zeD6Qyv?N$0^c<^jkk{{0GjQosCzKx(*|B`_7%b!V?3IA;o$>geX|A>A$iI?Q_TOXb zH8&3t9FeS*dnzLng}X&#vyHuSQu7QkNqZf%MHGq-$ba#g^%+Sb0pvWrQ0}8DnAIzj zauB-CUE5*?I;d0{d89fF!Rreo<&!X$gx2o#%)Z&E-xNFWLhbvAA^-l~Z|&33yr&zs zLII^*{HUerd)707?~IDr=mNMVNiQ?Sg$Uw7nu{O*!rP5P7Zdm38j$*8stF+8$86Pk zrIt%XBQXFiS6{mYK|H2F8y>N-BLkZ)f;V7!ZD*0ud}FR`f`rxze~6PYKmj>4Blg>B zf)RLl;24xn=tT@zxGu^d2Q2OFTKrLmHWFHDl_21o4*^eqkl+^KVtx-540tiGK=$2T zM*WKs%q_NmeK?v>#{LWSklzD&v+(G?d(^2?b7oFzrN-Kp7s_gvrM<|@?|RbHRn0eb zkT+8DW`gsGTCaw!buNItzB2u`hA|E4s#DV33ewYefnIBsT|F~lW%8Zx{PN9D zH6sgx?4OF0REIOCD1Tt9Bp7o+y3z?BiuXBpzsauqSQDO3eh}eT?KJH`KfYWdB^Jo? 
z=gj=W0GpoSSi=0O>$QXISTwQ`ncSGk3uIc=?E_X%mClNDSe}I1A20K|C+W#9Pg53} z(L&d7)DWalw_~->v;_8Ng!Uf^E$hFfQ+ppy0hGgRLDpSJfC%lH2i&TamuH-y2yhDz z2{54w_CK(~!+TBDc`bQ7e|oBoltO*UWLE!Oe2YvjNyZm7TvD-FyjU3DxmZ0t9tGh9pzjTTC+r{t zz(>0__VeJ)l;rev5AiTbbRWy8J9u@j*>i>9-_M)RsIKOF0 zB}(+J$L3qu7T!LucXi1b2MJ-+QG5BXd`F2ed}!dIGb+}m%F=Z`HowZWNOkd$8Xw5} z^`r||ow*A19ww_sOhAyk%S3?W>h5V3vCrsY@S;i8J)9msVGh%JJ9;?kJb*J6v=u7N zhANt07%7t0**|$x2TnTRP>`EFgHx)!2{(U!!wc_&nRGy;Elts!K|K9yHNsuPC;J9q zjTs>GV74xa%}Xw!j{Jb$@$HR@hJB)c{z}S#YS0&JICX4smI=Pr$6C{b;vkBT!9POZ z8t_XGfPj$h6&%Q##9%j1Q&7ZE(KUe;1eG}zt&o`{d`fzdLn1J2mrET$A{tFuSs0~+QG zI8k4^93W+a|MbVbrFr6yS};!f>y}eSk}}oPQHwxo@Sdp?xa0g8Z91)(NWkqL58yXZ>Zr#6xuf5(+kq9(nVS2*9`<1m&`iid@_Y8-@Q-( zZ(y^4gU`PjBu|l+*-wRvg5I(8WMs`p{Km~UR8Pm0zJ9+zTxAQmY$P=8((TW?d+qs0 z(_IiJnq8DHWix0(N@SUN#Vt6j*~mtEz5Osg`Iw1X$5G`sZ%%I%;oXxF03Pr=uIBj! zOJefIZQoaBsT@M&{(&EdXewR)pW?1O9?G_DSC(X}>_W1nkTtuaQi?{N#=b<1E&IMq zmO_?fuZ)mHG?o~$8?x_P3{A{f%FfuDknOt$PtWsy@B6&ppWpn&?>F~7=Y3u0bzbLn z9>;my8Fc3PmBW~7cy*0#R=y}=kmlXw-y)38W`cz zcJD$ui%L1F5r_`=&3sIAqd6M?UuQ&c0c+of{H$J1!9+>%^pAh?U4TNC{2xyCU)>vW z2dG$m36NGK_=^Y79eU5OwdN3h(2a_8Sv=XkSG$1w%x(m>~*f zEu~3dEQi@s67!8&@e-}OS&vOBc5gZMJMe>2X60galHZswp9CPkb@j4@5G=B_@3SS6 z1%oCs2T+XEKN{p3jXfAaj;phD0fA;RbxX<=ORCt^Ik%u9=3mfE=k*V+3Fy$JCGw~ zw~mEiskHVdk7U}S3{_nE0N(_|Kmkwc%R=2X7d~6mb)nif!!Vfx&sU9s*Qee`M2eq1 zpr3>w&kzB7W^?<)Zn8{(&5uc5kw}f-`bMkET5IW>l`h|lma@9AO=-G$PtfowX?Tav za?Wi}P%@6@;Lf;YG`}* zX;VC7j?>-gjzEqjQ!z2kA|2!&cSHWU>rS@hQZxSz5cXgFWv`$zQI!FqSSR=TNke=> z3SKMBWi&)&LIj)Yt|?ID?&UnEsf@R_DC0cO4YY_dnz5w8v8NaFc+%{n8@&!IDiRxE4ibU+Ic z5Kd_32u%PYBx%P<+9N3g9AfEZdB@N4fjv56r7t>QBX3c&&U72onozUSxK8ITCm={l z;kjTeWNxboItN zt&4f0-4bdb>wG2vId;HlFyoIT%1mleR$bLtBMVp#9onA46DR&S%l}E^sxBB0pqg70 zO@y!A&=)PMQmOO3BK@`|u1!qmmfJE;{o+9P*LxZjC?)L-6El!#o2Yw>3%IbUx4S4Bm9VOy zTPdtxjTJaPtq(!r??=_Mz|eTZalc!$FH}Kneoq`SUyx15pnD-a5P^(9b|#yj7reHq)e@ov{F z=W=w^YT}7(j4kqJvH!Y^nOU(_)Ti^D9xylhxP-1HvT|vvHY1^sua#^ZQ4kiBpY^X^ z5PaOi7zbwD^eDYxmQ1B!X}l%vd0R^xH6tUt#$Q)BdJx$h zhTbpX-53WKo##Cgb;ZD0#Jh6UHLT`vHFTCa2}&| z(fwO6g*jh>;YkCgI4!3io+lsV;i)B6o;@L*tU&Vh6zt|wX6hh?01%-$DT-#YWyYcK zjhsOjSf_{uoYj$d{f?XIy|;|N!~sHzVsV466Gb9V(>2ig2m$ruzg95-Qc)x{OpCUj zF}zlxg;s(Eq%oGzEkEvQh#z$5qStI@p8&{A2Q{!VrzmQ{vyFP zH}NYS!Yryk_q(`m7D5InGbCmiOSvl9q@ttThae;2z+A$w9&bI611 zqttCO?UTfKxV%8CHc!L!?;G+EzlP1cYHm6fv`PNJ<14x)ans#Mfkb?Kr$6s7#^&J5 zSzH;cZhWOk1}X*wpM`?i{;HT~m~ik&9h|Owb1d_Fnss-qqo#|cO57=MA6?EXh!AZD zRsYSbCQz!{S2r~AT?#$wB>H*f(;iF}`NGXpnp?$PxuS;SrV>rSY_bEVW53~-950jf z=44h=wc6MgDd7npWrs2KP200$6q={8&*9b)IBtGrM1#|V-LcVDyGG1HRRSxWk*+um zZt7yWZq9<=K=O4wWkhbl5mMk5!{sLz8Su23+7Db67Nup;MXzGl^z^d2*J8U&XV_~r z1+u!)&Sd8Atk{>%5@!jbasmdvkk3FiF%TdlC;bwsX~AhQvZ-v89d=Gk5uJ77+o#T> zZP8cb39X=_>E}Gxgd9~^CGi^ZRoebdE4s+9$4iSeA9l7AyX6K9Qo4>y3Su3Zqnr1svS3Q^e+SN8c3*b zE*a^KU@-37;3QMNosBI#M_96(&r>&}=(gXZ;>`Z?&SZfDNf$HJQX<^7Z_(%d+SCy} z+~jJNwLsd7N;<@vcW#OOu2_Z%cX(w4bEpfXa33S)2e7j_{2~Qg)AxQ33*^~h{s})U zeRI0F*HxV{W=D9-lD75O`P;=qwsOYC$PXo>tuw0|%a_V2eeCJId_8LPG)!?_g(DL% z_j1qK)3D;MGZ%kdE!bWXdVW#w`><@*+L%WO(?FMVm@WgHb%%GMm|Cx9OMKX%EJWZwU@95C(|9(T@7hM zPox&Vk{d8mt*xbc?>bmcRSXgnNdwVuoiceP`Eip49|#h0vp;I6nw0ry&XEs>7FAjf zn6j?}F|qdA4Y`~*{Yk?YGw=)`ujfaeFn>DJ$+QnE!CYt2e}Ka$N`5PKuN6gYPKY`$&Tg+I!z1_v&Cw){7}x~HoD02gO2j@l>I8Q zRW^2mX6((}Kmp7oN#4+XrAW~JH1vAbZO3O(Zf!@Z~(}gQzmRpEkE%0Fx`_e95_BAa3ObqysgF0 z5H{E)d|poJve%cVT_Jf@`2!>6!u8Q)eX;sZKl*y8T>cnx)phyqASo;|BS0XSFD#R@ z3@N2q*CA!aH+b!D+76LJ^9ctA0oe;a&=btmn}GZz-jsnqQ9`foVgtd@iCnsko|-Kw z6KWIk4*jf{t0nhzk^T2|86u$LI!Z?zvU9ovw826EFw-2KLh_H)C?LtfnUq z(=<>H@$<{)1$vzqNz3OqQ%He(*+)R$oTth#!Qm*H`2n8+=tiqX0kw{|Vp@;Qo~fbR zbLz)vu#NZSeL2u5MD#|A}RV 
zX#L~Z6C5qK&_}#jIpc-2cp5B1m-Y!E^kbyW2rBkouz^w%n*RwrW=5Mfl?0Z@cb|T- zcMl*idUi4z5mVVEDbyeI<%;Soa~XIV?L=xG0mWbv*+CrQNi_mC*ciA&@J~huVbduY zyQC<}_{m0I@n*bL2J7CH$w)eJzgIItu?vOu2L zgxJvImx+mm=X|-heq6);paK`2J%9x2u?Cg&lUOQo%SnnFQ5$99S zUfc8tR%6fi%~Gk%u}}+N;b7auA9b3iR)w7QS|N9SwI1&wSia5*%IVvlBFO?4=0+lA zW2YNUFac!>E?(zRrPbA5b4xjWS#t(w?gUW2BFzuZ)vI^r@{{UIUwjRdYM+rj1*y<* zpRFFOoBi2kOM+_-->pXVnMMV}M?TiOh(L@pESWW(<|N1d4Et=?rInL!WZ!$!PzcbR zspi$5B)v}yzR+c2LncB(wdfd$;2&>rF9%o`&@6gZ)Ve!v9Hk8X%%dtv3VJKiuUEJqrZlM?Qs zW*sD+)ZQvFx1aU6xed*vy}&!<-(5M3-`@hHuc_2cWSTf#%_6duSHD*Bdb3`*zX{Nu zOGHUR+l-Y5=DJ<H26-_*CX=& z?m{)er#GZCl)XkflE`{yHflJHg6q`8l?P~?R}~vnUEQ+E{o{Pb>)O!Vt>73CkzV|~ z)t1s&m*bR*NVR>QJA2NyA(1iQNQo=qq*SM{6~1G%>gU+cx+J$bUA%8OZ7yEv%0*&R zScFfH?AYVUMk8o`Ro!(cO9E`aG#W`c$bI8I)wT=6-(4HXgaiIMZ1JfdfOtddzwcVkhH&^{_#2CD%7usUhEoZr4q`2io@qB{+ z%^{8tS6w05kQ-KEwc_T#ofZ)bS7~vkIYDQG8@rDJg{arfi%PnJQcfL*;?9tQ3oRa*25yv(+&I*!q;LXO5w9CYXOUq7bxP6s z-{JP%Ba)8mzYq~jkdb^`t@^9WuT+Vc;#@Yq~v zZ4Vcax!IEzDI78_IvjH-@h%d2a1v<+xj9L_o1_7eVKRmo8S@(S(8hB9<5#dfzT!{Z zE$pH?-fIg2F$4j?uCZXg_FF>&nr|NiFRGRK5BFCDQq6DN7L`8^20Y3KQTbN;mA>9L zUVgH=;N3)Pbu||V{_X*?QNAoY{qly?}<$g_%UY+s@@T&!_ zXxmrYaHCat0zMg1btylfF@(xQ9WC8(|dQ?Tfc7;iI{cVuhik6_t80IVr`*t%_~Im4@#6(2At?^zl-=9xv#uG z9bM#n@=7b4$I6I~AsgBo!%3y4EtoYer8^LvS!Oe|{K=_|77>2c+t7$(>1UCak1DZW zv-|b0EPQ`(xwyhDSGT^ID2jd2?QnYg>r%lxlSr}Sju@Pev&-(-1tv6L6(P3n+Z%^oOR^`HQdZ(Q~ExLXPOo^6e+Nd221r;Yf8f zZ?5dQ@noj*Ck@~sCConZ7t=0%FT`#SVYu$zkWVy*WIasD|D1)HSXy|AnN|4p>(oXQ zL1wTwqVPgQ==jpgLR3%$8RdZCrtwEFvLE83gs#yCSZr=mjCDi@F_km#>S+^x$7;u_ zn|EvK`!qe2?ukT9 zH!P)xf*yg7&HD2{UdJ;4pPSo^wA>zOdr3S!=GNn%>(L6k3Y!qx2R1&qK2w-M}FF zry1F9aSagt@K}U|8NcAYxs>iHOwNSQ+SFum*bnv^KLa<^R{FKKR;QgNlf+|Tb+@1j zi~?tx|J#y?(!T;BP4s*=0X?1X(6Tq6*aH(9`FD++Ce?Pu* uCMMxplu>>d__WlV9uVdQpWT3+L#|9QWlatdsU#6!wHtS?f4Fu(@P7cv2xV;m literal 0 HcmV?d00001 diff --git a/example/ck_tile/01_fmha/script/benchmark.sh b/example/ck_tile/01_fmha/script/benchmark.sh new file mode 100755 index 000000000..859cff9f6 --- /dev/null +++ b/example/ck_tile/01_fmha/script/benchmark.sh @@ -0,0 +1,32 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/tile_example_fmha_fwd +VALID=0 + +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 64 128 256 ; do + +nhead=$((2048 / $hdim)) # follow fav2 setup +$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3 + +done +done +done + +for perm in 0 1 ; do + +$EXE -prec=fp8 -squant=1 -b=32 -h=16 -d=128 -s=512 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=16 -h=16 -d=128 -s=1024 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=8 -h=16 -d=128 -s=2048 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=4 -h=16 
-d=128 -s=4096 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=2 -h=16 -d=128 -s=8192 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 +$EXE -prec=fp8 -squant=1 -b=1 -h=16 -d=128 -s=16384 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3 + +done \ No newline at end of file diff --git a/example/ck_tile/01_fmha/script/smoke_test.sh b/example/ck_tile/01_fmha/script/smoke_test.sh new file mode 100755 index 000000000..4dd5c2ae1 --- /dev/null +++ b/example/ck_tile/01_fmha/script/smoke_test.sh @@ -0,0 +1,45 @@ +#!/bin/sh +# TODO: run this script from CK root +BUILD=build +EXE=$BUILD/bin/tile_example_fmha_fwd +KNAME=1 + +export CK_WARMUP=0 +export CK_REPEAT=1 + +COMMON_ARGS='-v=1 -warmup=0 -repeat=1' +# mode=0 +# export HIP_VISIBLE_DEVICES=4 + +for prec in "fp16" "bf16" ; do +for mode in 1 0 ; do +for perm in 0 1 ; do +for vlayout in "r" "c" ; do +for hdim in 32 64 128 256 ; do +for lse in 0 1 ; do +for bias in 0 1 ; do + +# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS + +done +done +done +done +done +done +done + +for perm in 0 1 ; do +for bias in 0 1 ; do +for b in 1 2 ; do +$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS +done +done +done diff --git a/example/ck_tile/01_fmha/utils.hpp b/example/ck_tile/01_fmha/utils.hpp new file mode 100644 index 000000000..e10ae617d --- /dev/null +++ b/example/ck_tile/01_fmha/utils.hpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core/container/span.hpp" + +enum class mode_enum +{ + batch = 0, + group +}; + +std::ostream& operator<<(std::ostream& stream, mode_enum mode) +{ + return stream << (mode == mode_enum::batch ? 
"batch" : "group"); +} + +std::vector to_seqstarts(ck_tile::span seqlens) +{ + std::vector seqstarts = {0}; + for(int32_t seqlen : seqlens) + { + seqstarts.push_back(seqstarts.back() + seqlen); + } + assert(seqstarts.size() == seqlens.size() + 1); + return seqstarts; +} + +std::vector generate_seqlens(mode_enum mode, + unsigned count, + int32_t seqlens_sum, + std::optional seed = std::nullopt) +{ + assert(0 < count); + + std::vector seqlens(count, seqlens_sum); + + if(mode == mode_enum::group && 1 < count) + { + using size_type = std::vector::size_type; + + std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}()); + std::uniform_int_distribution idx_dist(0, count - 1); + auto next_idx = std::bind(idx_dist, std::ref(random_engine)); + + std::uniform_int_distribution step_dist(1, count - 1); + auto next_step = std::bind(step_dist, std::ref(random_engine)); + + for(unsigned repeat = seqlens_sum * (count / 2); 0 < repeat; --repeat) + { + const size_type to_decrease = next_idx(); + // make sure each elements of seqlens is always greater than 0 + if(seqlens[to_decrease] == 1) + { + continue; + } + + const size_type to_increase = (to_decrease + next_step()) % count; + + --seqlens[to_decrease]; + ++seqlens[to_increase]; + } + } + + return seqlens; +} + +std::vector generate_seqstarts(mode_enum mode, + unsigned count, + int32_t seqlens_sum, + std::optional seed = std::nullopt) +{ + return to_seqstarts(generate_seqlens(mode, count, seqlens_sum, seed)); +} + +int env_get_int(const char* var_name, int default_int) +{ + char* v = getenv(var_name); + int r = default_int; + if(v) + r = atoi(v); + return r; +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt new file mode 100644 index 000000000..d2b086e04 --- /dev/null +++ b/example/ck_tile/CMakeLists.txt @@ -0,0 +1,5 @@ +include_directories(AFTER + ${CMAKE_CURRENT_LIST_DIR} +) + +add_subdirectory(01_fmha) diff --git a/example/ck_tile/remod.py b/example/ck_tile/remod.py new file mode 100644 index 000000000..fdc0dcf5d --- /dev/null +++ b/example/ck_tile/remod.py @@ -0,0 +1,21 @@ +import pathlib +from pathlib import Path +import subprocess +import os +import copy + +all_files = [] +for p in sorted(Path("./").rglob("*")): + if p.suffix in ['.hpp', '.cpp']: + all_files.append(pathlib.PurePath(p)) + + +# formatting +for x in all_files: + subprocess.Popen(f'dos2unix {str(x)}', shell=True) + cmd = f'clang-format-12 -style=file -i {str(x)}' + #for xp in x.parents: + #print(get_file_base(x)) + subprocess.Popen(cmd, shell=True) + +#print(all_files) diff --git a/include/ck_tile/README.md b/include/ck_tile/README.md new file mode 100644 index 000000000..572e9c7e4 --- /dev/null +++ b/include/ck_tile/README.md @@ -0,0 +1,48 @@ +# ck_tile +## concept +`ck_tile` provides a programming model with templated abstractions to enable users to implement performance-critical kernels for machine learning workloads. introduces following basic concepts to help users building your own operator + - tensor coordinate transformation, this is the core concept of layout/index transform abstraction in both compiler time and run time. + - tile-based programming model, including tile-level api and the concept of distributed tensor. + +`ck_tile` is independently from the old ck, located under [/include/ck_tile](/include/ck_tile). You don't need to include anything from old CK, `ck_tile` has similiar (indeed almost the same) implementations for users to build operators. 
+We will have a transition period to pull everything from the old CK into `ck_tile`, stay tuned.
+
+## component
+`ck_tile` is split into several components, including `core`, `host`, `ops/gemm`, `ops/fmha`... For each component you only need to include a single header (e.g. `#include "ck_tile/core.hpp"`, `#include "ck_tile/ops/fmha.hpp"`) and then you are able to use the functions/structures inside (different from the old `ck`).
+
+**[core]**
+`ck_tile/core` contains all the basic data structures and functions to build a kernel; you can include only this header and build your own operators utilizing all the basic building blocks introduced in CK.
+
+`core/container`
+ - array, stores runtime variables of fixed length (tensor index, register buffer, etc...)
+ - tuple, same as std::tuple, holds different types of data, and is one way to achieve multiple buffers.
+ - sequence, compile-time integer sequence used to build various internal structures, or to describe tile sizes
+ - other convenient structures built on top of the above 3
+
+`core/numeric`
+ - gpu data types like `fp16_t`, `bf16_t`, `fp8_t`... and the conversions between them
+ - constexpr integer similar to std::integral_constant, to be used as a compile-time integer.
+ - math functions and numeric utilities
+
+`core/algorithm`
+ - coordinate transformation system, used to build tensor transforms and compile-time indexing. This is the core idea introduced in the old `ck` to describe how a tensor is built from several basic transform primitives like `merge`/`unmerge`/`embed` etc..., and how we index into an ND tensor that is finally mapped to a 1D memory offset.
+
+`core/tensor`
+ - tensor descriptor, to describe how an ND tensor is laid out
+ - distributed tensor, describes the storage of this tensor and the distribution of how a collection of threads collaboratively works on this tensor.
+ - tile-level API, including `load_tile`, `store_tile`, `shuffle_tile`, `slice_tile`, etc...
+
+**[host]**
+`ck_tile/host` contains all the host-side utilities to launch a kernel, create device buffers, and some reference implementations. It can be used to create examples (like those under the ck_tile example folder) and simple executables that invoke a kernel, so if you only need `ck_tile` to build your own device library it's OK not to include this. For that reason, it is recommended to include only the specific headers you need under this folder to avoid pulling in unwanted headers (e.g. only include `ck_tile/host/kernel_launch.hpp`), unless you are writing a host executable.
+
+**[ops/gemm, ops/fmha, ops/reduce...]**
+our implementations of different device operators.
+ - warp, warp-tile-level operators
+ - block, block-tile-level operators
+ - pipeline, a pipeline that implements a customized tile-level mainloop (or epilogue). By switching the pipeline given to the kernel template you can get different kinds of pipeline optimizations.
+ - kernel, the template interface for users to instantiate a particular kernel
+
+**[ops/epilogue]**
+the epilogue part of our kernels. We may extend this epilogue part to let users build their own customized epilogues.
+
+## examples
+Currently we put all ck_tile related examples under the [/example/ck_tile](/example/ck_tile/) folder. Please check each example's subfolder.
diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp
new file mode 100644
index 000000000..bb19c9154
--- /dev/null
+++ b/include/ck_tile/core.hpp
@@ -0,0 +1,58 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
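`core.hpp`, whose listing begins here, is the single umbrella header the component section above refers to. The minimal sketch below shows the kind of compile-time building blocks it pulls in; it is an illustration only, not part of the header, the name `example_tile_shape` and the sizes are made up, and the exact constructor forms are assumptions based on how `number`, `sequence`, and `make_tuple` are used elsewhere in this patch.

```cpp
#include "ck_tile/core.hpp"

// tuple/number/sequence are the compile-time containers listed under
// core/container and core/numeric; everything comes in through this one header.
CK_TILE_HOST_DEVICE constexpr auto example_tile_shape()
{
    using namespace ck_tile;

    constexpr auto lengths = make_tuple(number<128>{}, number<32>{}); // compile-time tile sizes
    constexpr auto order   = sequence<1, 0>{};                        // compile-time dimension order

    return make_tuple(lengths, order);
}
```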
+ +#pragma once + +#include "ck_tile/core/algorithm/cluster_descriptor.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/algorithm/space_filling_curve.hpp" +#include "ck_tile/core/arch/amd_buffer_addressing.hpp" +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/utility.hpp" +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/map.hpp" +#include "ck_tile/core/container/meta_data_buffer.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/span.hpp" +#include "ck_tile/core/container/statically_indexed_array.hpp" +#include "ck_tile/core/container/thread_buffer.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/numeric/type_convert.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/core/tensor/buffer_view.hpp" +#include "ck_tile/core/tensor/load_tile.hpp" +#include "ck_tile/core/tensor/null_tensor.hpp" +#include "ck_tile/core/tensor/null_tile_window.hpp" +#include "ck_tile/core/tensor/shuffle_tile.hpp" +#include "ck_tile/core/tensor/slice_tile.hpp" +#include "ck_tile/core/tensor/static_distributed_tensor.hpp" +#include "ck_tile/core/tensor/store_tile.hpp" +#include "ck_tile/core/tensor/sweep_tile.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp" +#include "ck_tile/core/tensor/tensor_coordinate.hpp" +#include "ck_tile/core/tensor/tensor_descriptor.hpp" +#include "ck_tile/core/tensor/tensor_view.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/tensor/tile_distribution_encoding.hpp" +#include "ck_tile/core/tensor/tile_elementwise.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/ignore.hpp" +#include "ck_tile/core/utility/magic_div.hpp" +#include "ck_tile/core/utility/random.hpp" +#include "ck_tile/core/utility/to_sequence.hpp" +#include "ck_tile/core/utility/transpose_vectors.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/utility/unary_element_function.hpp" diff --git a/include/ck_tile/core/README.md b/include/ck_tile/core/README.md new file mode 100644 index 000000000..d2ecfabae --- /dev/null +++ b/include/ck_tile/core/README.md @@ -0,0 +1,18 @@ +# ck_tile/core # + +`ck_tile/core` contains every basic functions and structures to create a GPU kernel using `ck_tile`. User should only include `ck_tile/core.hpp` this single header to use all the functionality. Everything is under `ck_tile` namespace. The coding style under this folder should be similar to `std` (`snake_case` for structure/function, Camel for template types...) + +``` +algorithm/ + coordinate transform and some other reusable algorithm +arch/ + contains some basic device building block like mma, buffer addressing, etc... +container/ + contains basic container data structure, array/sequence/tuple/... 
+numeric/ + data type, and data type related math +tensor/ + tensor descriptors and tile level API +utility/ + other utility function for both host/device +``` diff --git a/include/ck_tile/core/algorithm/cluster_descriptor.hpp b/include/ck_tile/core/algorithm/cluster_descriptor.hpp new file mode 100644 index 000000000..c59a7c1fa --- /dev/null +++ b/include/ck_tile/core/algorithm/cluster_descriptor.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template ::type> +CK_TILE_HOST_DEVICE constexpr auto make_cluster_descriptor( + const Lengths& lengths, + ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type{}) +{ + constexpr index_t ndim_low = Lengths::size(); + + const auto reordered_lengths = container_reorder_given_new2old(lengths, order); + + const auto low_lengths = generate_tuple( + [&](auto idim_low) { return reordered_lengths[idim_low]; }, number{}); + + const auto transform = make_merge_transform(low_lengths); + + constexpr auto low_dim_old_top_ids = ArrangeOrder{}; + + constexpr auto up_dim_new_top_ids = sequence<0>{}; + + return make_single_stage_tensor_adaptor( + make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids)); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp new file mode 100644 index 000000000..71602e5d1 --- /dev/null +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -0,0 +1,1672 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
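Before the transform definitions in `coordinate_transform.hpp`, a tiny worked example of the lowering rule that the `embed` transform documents below, idx_low = sum_i coefficients[i] * idx_up[i]. This is plain C++ with no ck_tile dependency, added purely as an illustration; the function name is made up.

```cpp
// Illustration of the embed<> lowering rule: a row-major (M, N) view uses
// coefficients (N, 1), so the upper index (row, col) lowers to row * N + col.
constexpr int embed_lower_2d(int row, int col, int num_cols)
{
    return row * num_cols + col * 1;
}
static_assert(embed_lower_2d(3, 5, 64) == 3 * 64 + 5, "row 3, col 5 of a 64-column view");
```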
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/utility/magic_div.hpp" + +namespace ck_tile { + +enum struct coord_transform_enum +{ + undefined, + pass_through, + pad, + embed, + merge, + unmerge, + replicate, + xor_t, + offset, +}; + +template +struct base_transform +{ + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() + { + return coord_transform_enum::undefined; + } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return NDimLow; } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return NDimUp; } + + // return safe value for vector length/stride, based on compile-time known only + // variables + // MUST be static function + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths&, + const LowVectorStrides&) + { + if constexpr(NDimUp > 0) + { + array up_vector_lengths{-1}; + array up_vector_strides{-1}; + + return make_tuple(up_vector_lengths, up_vector_strides); + } + else + { + return make_tuple(array{}, array{}); + } + } +}; + +template +struct pass_through : public base_transform<1, 1> +{ + static constexpr auto type_enum = coord_transform_enum::pass_through; + + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + UpLengths up_lengths_; + + CK_TILE_HOST_DEVICE constexpr pass_through() = default; + + CK_TILE_HOST_DEVICE constexpr pass_through(const LowLength& low_length) + : up_lengths_{make_tuple(low_length)} + { + } + + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() + { + return coord_transform_enum::pass_through; + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}]; + } + + template + CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + + idx_diff_low[I0] = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + // MUST be static function + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + return make_tuple(low_vector_lengths, low_vector_strides); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("pass_through{"); + + // + printf("up_lengths_:"); + print(up_lengths_); + + // + printf("}"); + } +}; + +template +struct pad : public base_transform<1, 1> +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{})); + + UpLengths up_lengths_; + LeftPadLength left_pad_length_; + RightPadLength right_pad_length_; + + CK_TILE_HOST_DEVICE constexpr pad() : up_lengths_{}, left_pad_length_{}, right_pad_length_{} {} + + CK_TILE_HOST_DEVICE constexpr pad(const LowLength& low_length, + const LeftPadLength& left_pad_length, + const RightPadLength& right_pad_length) + : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)}, + left_pad_length_{left_pad_length}, + right_pad_length_{right_pad_length} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_; + } + + template + CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + + idx_diff_low[I0] = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return SkipIsValidCheck; + } + + template + CK_TILE_HOST_DEVICE constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const + { + return SkipIsValidCheck || + ((idx_up[number<0>{}] >= left_pad_length_) && + (idx_up[number<0>{}] < up_lengths_[number<0>{}] - right_pad_length_)); + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("pad{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("left_pad_length_: "); + print(left_pad_length_); + printf(", "); + + // + printf("right_pad_length_: "); + print(right_pad_length_); + + printf("}"); + } +}; + +template +struct left_pad +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{})); + + UpLengths up_lengths_; + LeftPadLength left_pad_length_; + + CK_TILE_HOST_DEVICE constexpr left_pad() = default; + + CK_TILE_HOST_DEVICE constexpr left_pad(const LowLength& low_length, + const LeftPadLength& left_pad_length) + : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_; + } + + template + CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + + idx_diff_low[I0] = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return SkipIsValidCheck; + } + + template + CK_TILE_HOST_DEVICE constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const + { + return SkipIsValidCheck || (idx_up[number<0>{}] >= left_pad_length_); + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + // MUST be static function + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + // TODO: we allow pass through this vector length. If one need per-pixel check, + // should change the guaranteed vector length while creating the tensor view. 
+ // It's up to runtime to check the padding length should be multiple of vector length + return make_tuple(low_vector_lengths, low_vector_strides); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("left_pad{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("left_pad_length_: "); + print(left_pad_length_); + + printf("}"); + } +}; + +template +struct right_pad : public base_transform<1, 1> +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{})); + + UpLengths up_lengths_; + LowLength low_length_; + RightPadLength right_pad_length_; + + CK_TILE_HOST_DEVICE constexpr right_pad() = default; + + CK_TILE_HOST_DEVICE constexpr right_pad(const LowLength& low_length, + const RightPadLength& right_pad_length) + : up_lengths_{make_tuple(low_length + right_pad_length)}, + low_length_{low_length}, + right_pad_length_{right_pad_length} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}]; + } + + template + CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + + idx_diff_low[I0] = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return SkipIsValidCheck; + } + + template + CK_TILE_HOST_DEVICE constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const + { + return SkipIsValidCheck || (idx_up[number<0>{}] < low_length_); + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + // MUST be static function + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + // TODO: we allow pass through this vector length. If one need per-pixel check, + // should change the guaranteed vector length while creating the tensor view. 
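+        // For example (illustrative numbers only): with low_length_ == 60 and
+        // right_pad_length_ == 4 the upper length is 64; upper indices 60..63 are
+        // padding, and is_valid_upper_index_mapped_to_valid_lower_index() returns
+        // false for them unless SkipIsValidCheck is set.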
+ // It's up to runtime to check the padding length should be multiple of vector length + return make_tuple(low_vector_lengths, low_vector_strides); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("right_pad{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("right_pad_length_: "); + print(right_pad_length_); + + printf("}"); + } +}; + +// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] +// UpLengths and Coefficients can be either of the followings: +// 1) Tuple of index_t, which is known at run-time, or +// 2) Tuple of number, which is known at compile-time, or +// 3) Tuple of mixture of index_t and number, which is known partially at run-time and partially +// at compile-time +template ::type = false> +struct embed : public base_transform<1, UpLengths::size()> +{ + static constexpr index_t NDimUp = UpLengths::size(); + + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index; + + UpLengths up_lengths_; + Coefficients coefficients_; + + CK_TILE_HOST_DEVICE constexpr embed() = default; + + CK_TILE_HOST_DEVICE constexpr embed(const UpLengths& up_lengths, + const Coefficients& coefficients) + : up_lengths_{up_lengths}, coefficients_{coefficients} + { + } + + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() + { + return coord_transform_enum::embed; + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = 0; + + static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) { + idx_low(number<0>{}) += idx_up[i] * this->coefficients_[i]; + }); + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) const + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp && + LowIdx::size() == 1 && UpIdx::size() == NDimUp, + "wrong! inconsistent # of dimension"); + + idx_diff_low(number<0>{}) = 0; + + static_for<0, NDimUp, 1>{}( + [&](auto i) { idx_diff_low(number<0>{}) += idx_diff_up[i] * coefficients_[i]; }); + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("embed{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("coefficients_: "); + print(coefficients_); + + printf("}"); + } +}; + +template +struct lambda_merge_generate_MagicDivision_calculate_magic_divisor +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(number i) const + { + return magic_division::calculate_magic_numbers(LowLengths{}[i]); + } +}; + +// Implementation of "merge" transformation primitive that uses magic-number-division to do lowering +// of both multi-index and delta of multi-index +// Caution: +// 1. 
The magic number division implementation being used would produce correct result if the +// dividended is uint32_t and its value is with in 31-bit value range of uint32_t. +// 2. The magic number division for int32_t dividened has not been implemented, the int32_t +// dividend would be bit-wise interpreted as uint32_t and magic number division implementation for +// uint32_t is then used. +// 3. For merge primitive, upper-index is the dividend. +// 4. When upper-index is uint32_t, its value need to be within 31-bit range. +// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be +// non-negative. +template +struct merge_v2_magic_division : public base_transform +{ + static constexpr index_t NDimLow = LowLengths::size(); + + using LowerIndex = multi_index; + using UpperIndex = multi_index<1>; + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{}))); + + using LowLengthsMagicDivisor = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_divisor{}, + number{})); + + LowLengths low_lengths_; + LowLengthsMagicDivisor low_lengths_magic_divisor_; + UpLengths up_lengths_; + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + + CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division() = default; + + CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_magic_divisor_{generate_tuple( + [&](auto i) { return magic_division::calculate_magic_numbers(low_lengths[i]); }, + number{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, I1))} + { + static_assert(LowerIndex::size() == NDimLow, "wrong!"); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() + { + return coord_transform_enum::merge; + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[I0]; + + static_for{}([&, this](auto i) { + index_t tmp2 = + magic_division::do_magic_division(tmp, + this->low_lengths_magic_divisor_[i][I0], + this->low_lengths_magic_divisor_[i][I1]); + idx_low(i) = tmp - tmp2 * this->low_lengths_[i]; + tmp = tmp2; + }); + + idx_low(number<0>{}) = tmp; + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new) const + { + static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 && + LowIdx::size() == NDimLow && UpIdx::size() == 1, + "wrong! 
inconsistent # of dimension"); + + index_t tmp = idx_up_new[number<0>{}]; + + static_for{}([&, this](auto i) { + index_t tmp2 = + magic_division::do_magic_division(tmp, + this->low_lengths_magic_divisor_[i][I0], + this->low_lengths_magic_divisor_[i][I1]); + + index_t idx_low_old = idx_low[i]; + + idx_low(i) = tmp - tmp2 * this->low_lengths_[i]; + tmp = tmp2; + + idx_diff_low(i) = idx_low[i] - idx_low_old; + }); + + idx_diff_low(number<0>{}) = tmp - idx_low(number<0>{}); + + idx_low(number<0>{}) = tmp; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + // MUST be static function + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + array up_vector_lengths{-1}; + array up_vector_strides{-1}; + + up_vector_lengths[0] = low_vector_lengths[number{}]; + up_vector_strides[0] = low_vector_strides[number{}]; + + return make_tuple(up_vector_lengths, up_vector_strides); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("merge_v2_magic_division{"); + + // + printf("low_lengths_ "); + print(low_lengths_); + printf(", "); + + // + printf("up_lengths_ "); + print(up_lengths_); + + printf("}"); + } +}; + +// Implementation of "merge" transformation primitive that uses division and mod. It is supposed to +// be used for low_lengths that are known at compile time and are power of 2, otherwise performance +// will be very bad +template +struct merge_v3_division_mod : public base_transform +{ + static constexpr index_t NDimLow = LowLengths::size(); + + using LowerIndex = multi_index; + using UpperIndex = multi_index<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies{}, number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod() = default; + + CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, multiplies{}, number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, number<1>{}))} + { + static_assert(LowerIndex::size() == NDimLow, "wrong!"); + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1, + "wrong! 
inconsistent # of dimension"); + + index_t tmp = idx_up[number<0>{}]; + + // division and mod + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp %= this->low_lengths_scan_[i]; + }); + + idx_low(number{}) = tmp; + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new) const + { + static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 && + LowIdx::size() == NDimLow && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + constexpr auto INm1 = number{}; + + index_t tmp = idx_up_new[I0]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + const index_t tmp2 = idx_low[i]; + idx_low(i) = tmp / this->low_lengths_scan_[i]; + idx_diff_low(i) = idx_low[i] - tmp2; + tmp %= this->low_lengths_scan_[i]; + }); + + const index_t tmp2 = idx_low[INm1]; + idx_low(INm1) = tmp; + idx_diff_low(INm1) = idx_low[INm1] - tmp2; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + // MUST be static function + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + array up_vector_lengths{-1}; + array up_vector_strides{-1}; + + up_vector_lengths[0] = low_vector_lengths[number{}]; + up_vector_strides[0] = low_vector_strides[number{}]; + + return make_tuple(up_vector_lengths, up_vector_strides); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("Merge_v3_direct_division_mod{"); + + // + printf("low_lengths_ "); + print(low_lengths_); + printf(", "); + + // + printf("low_lengths_scan_ "); + print(low_lengths_scan_); + printf(", "); + + // + printf("up_lengths_ "); + print(up_lengths_); + + printf("}"); + } +}; + +template +struct unmerge : public base_transform<1, UpLengths::size()> +{ + static constexpr index_t NDimUp = UpLengths::size(); + + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index; + + using UpLengthsScan = + decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies{}, number<1>{})); + + UpLengths up_lengths_; + UpLengthsScan up_lengths_scan_; + + CK_TILE_HOST_DEVICE constexpr unmerge() = default; + + CK_TILE_HOST_DEVICE constexpr unmerge(const UpLengths& up_lengths) + : up_lengths_{up_lengths}, + up_lengths_scan_{container_reverse_exclusive_scan(up_lengths, multiplies{}, number<1>{})} + { + } + + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() + { + return coord_transform_enum::unmerge; + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + if constexpr(!Use24BitIntegerCalculation) + { + idx_low(number<0>{}) = idx_up[number{}]; + + static_for<0, NDimUp - 1, 1>{}( + [&](auto i) { idx_low(number<0>{}) += idx_up[i] * up_lengths_scan_[i]; }); + } + else + { + idx_low(number<0>{}) = 
idx_up[number{}]; + + static_for<0, NDimUp - 1, 1>{}([&](auto i) { + idx_low(number<0>{}) = + (0x00ffffff & idx_low[number<0>{}]) + + (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]); + }); + } + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) const + { + calculate_lower_index(idx_diff_low, idx_diff_up); + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + // MUST be static function + template + CK_TILE_HOST_DEVICE static constexpr auto + calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) + { + array up_vector_lengths{-1}; + array up_vector_strides{-1}; + + constexpr auto up_length_last = UpLengths{}[number{}]; + + if constexpr(ck_tile::is_known_at_compile_time::value) + { + if(low_vector_lengths[0] != -1) + { + up_vector_lengths(NDimUp - 1) = gcd(low_vector_lengths[0], up_length_last); + } + } + + up_vector_strides(NDimUp - 1) = low_vector_strides[0]; + + return make_tuple(up_vector_lengths, up_vector_strides); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("unmerge{"); + + // + printf("up_lengths_"); + print(up_lengths_); + printf(", "); + + // + printf("up_lengths_scan_"); + print(up_lengths_scan_); + + printf("}"); + } +}; + +template +struct freeze : public base_transform<1, 0> +{ + LowerIndex low_idx_; + + CK_TILE_HOST_DEVICE constexpr freeze() = default; + + CK_TILE_HOST_DEVICE constexpr freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {} + + CK_TILE_HOST_DEVICE static constexpr auto get_upper_lengths() { return tuple<>{}; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& /* idx_up */) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 0, + "wrong! 
inconsistent # of dimension"); + + idx_low(number<0>{}) = low_idx_; + } + + template + CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& /* idx_diff_up */, + LowIdx& /* idx_low */, + const UpIdx& /* idx_up_new */) + { + idx_diff_low(number<0>{}) = 0; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("freeze{"); + + // + printf("low_idx_: "); + print(low_idx_); + + printf("}"); + } +}; + +// insert a dangling upper dimension without lower dimension +template +struct insert : public base_transform<0, 1> +{ + using UpLengths = decltype(make_tuple(UpperLength{})); + + UpLengths up_lengths_; + + CK_TILE_HOST_DEVICE constexpr insert() = default; + + CK_TILE_HOST_DEVICE constexpr insert(const UpperLength& up_length) + : up_lengths_{make_tuple(up_length)} + { + } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return 0; } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return 1; } + + CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const + { + static_assert(LowIdx::size() == 0 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + } + + template + CK_TILE_HOST_DEVICE static void + update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&) + { + static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == 1 && LowIdx::size() == 0 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + } + + CK_TILE_HOST_DEVICE static constexpr bool IsLinearTransform() { return true; } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("insert{"); + + // + print(up_lengths_); + + printf("}"); + } +}; + +// replicate the original tensor and create a higher dimensional tensor +template +struct replicate : public base_transform<0, UpLengths::size()> +{ + static constexpr index_t NDimUp = UpLengths::size(); + + CK_TILE_HOST_DEVICE constexpr replicate() = default; + + CK_TILE_HOST_DEVICE constexpr replicate(const UpLengths& up_lengths) : up_lengths_{up_lengths} + { + } + + CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const + { + static_assert(LowIdx::size() == 0 && UpIdx::size() == NDimUp, + "wrong! 
inconsistent # of dimension"); + } + + template + CK_TILE_HOST_DEVICE static void + update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&) + { + static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == NDimUp && + LowIdx::size() == 0 && UpIdx::size() == NDimUp, + "wrong! inconsistent # of dimension"); + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("replicate{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + + printf("}"); + } + + // + UpLengths up_lengths_; +}; + +template +struct slice : public base_transform<1, 1> +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{})); + + UpLengths up_lengths_; + SliceBegin slice_begin_; + SliceEnd slice_end_; + + CK_TILE_HOST_DEVICE constexpr slice() = default; + + CK_TILE_HOST_DEVICE constexpr slice(const LowLength&, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) + : up_lengths_{make_tuple(slice_end - slice_begin)}, + slice_begin_{slice_begin}, + slice_end_{slice_end} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}] + slice_begin_; + } + + template + CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + + idx_diff_low[I0] = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx&) const + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("slice{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("slice_begin_: "); + print(slice_begin_); + printf(", "); + + // + printf("slice_end_: "); + print(slice_end_); + + printf("}"); + } // namespace ck +}; // namespace ck + +/* + * \brief lower_idx = upper_idx % modulus. + * TODO: Need an improved implementation since the modulo operation is expensive. 
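+ *
+ * Example: with modulus 4 the upper indices 0,1,2,...,7 map to the lower
+ * indices 0,1,2,3,0,1,2,3. One possible direction for the TODO is to
+ * special-case a compile-time power-of-two modulus, where the remainder of a
+ * non-negative index can be computed as (idx & (modulus - 1)) instead of an
+ * integer division.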
+ */ +template +struct modulo : public base_transform<1, 1> +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + using UpLengths = decltype(make_tuple(UpLength{})); + + Modulus modulus_; + UpLengths up_lengths_; + + CK_TILE_HOST_DEVICE constexpr modulo() = default; + + CK_TILE_HOST_DEVICE constexpr modulo(const Modulus& modulus, const UpLength& up_length) + : modulus_{modulus}, up_lengths_{make_tuple(up_length)} + { + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}] % modulus_; + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& up_idx) const + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + + const auto idx_low_old = idx_low; + idx_low[I0] = (up_idx[I0] + idx_diff_up[I0]) % modulus_; + idx_diff_low[I0] = idx_low - idx_low_old; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("Modulus{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + + printf("}"); + } +}; + +// 2D XOR, NOTE: "xor" is a keyword +template +struct xor_t : public base_transform<2, 2> +{ + static constexpr auto type_enum = coord_transform_enum::xor_t; + + using LowerIndex = multi_index<2>; + using UpperIndex = multi_index<2>; + + using UpLengths = LowLengths; + + UpLengths up_lengths_; + RightShift right_shift_; + + CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {} + + CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths, + const RightShift& right_shift) + : up_lengths_{low_lengths}, right_shift_{right_shift} + { + } + + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() + { + return coord_transform_enum::xor_t; + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 2 && UpIdx::size() == 2, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}]; + + const auto idx_low_1_tmp = + (idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}]; + + const auto idx_low_1 = + (idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp; + + idx_low(number<1>{}) = idx_low_1; + } + + template + CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdxDiff::size() == 2 && UpIdxDiff::size() == 2 && LowIdx::size() == 2 && + UpIdx::size() == 2, + "wrong! 
inconsistent # of dimension"); + + const auto idx_low_old = idx_low; + + calculate_lower_index(idx_low, idx_up); + + idx_diff_low = idx_low - idx_low_old; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */) + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + // MUST be static function + template + CK_TILE_HOST_DEVICE constexpr auto calculate_upper_dimension_safe_vector_length_strides( + const LowVectorLengths& low_vector_lengths, + const LowVectorStrides& low_vector_strides) const + { + array up_vector_lengths = low_vector_lengths; + array up_vector_strides = low_vector_strides; + + if constexpr(ck_tile::is_known_at_compile_time::value) + { + if(low_vector_lengths[1] != -1) + { + up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_)); + } + } + + return make_tuple(up_vector_lengths, up_vector_strides); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("xor_t{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("right_shift_: "); + print(right_shift_); + + printf("}"); + } +}; + +template +struct offset : public base_transform<1, 1> +{ + using LowerIndex = multi_index<1>; + using UpperIndex = multi_index<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + UpLengths up_lengths_; + OffsetLength offset_length_; + + CK_TILE_HOST_DEVICE constexpr offset() = default; + + CK_TILE_HOST_DEVICE constexpr offset(const LowLength& low_length, + const OffsetLength& offset_length) + : up_lengths_{make_tuple(low_length)}, offset_length_{offset_length} + { + } + + CK_TILE_HOST_DEVICE static constexpr auto get_type_enum() + { + return coord_transform_enum::offset; + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; } + + template + CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::size() == 1 && UpIdx::size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(number<0>{}) = idx_up[number<0>{}] + offset_length_; + } + + template + CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&) + { + static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && + UpIdx::size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = number<0>{}; + + idx_diff_low[I0] = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + CK_TILE_HOST_DEVICE static constexpr bool + is_valid_upper_index_always_mapped_to_valid_lower_index() + { + return true; + } + + template + CK_TILE_HOST_DEVICE constexpr bool + is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx&) const + { + return true; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() + { + return ck_tile::is_known_at_compile_time::value && + ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("offset{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + // + printf("offset_length_: "); + print(offset_length_); + + printf("}"); + } +}; + +//******************************************************************************************************* + +template +CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength& low_length) +{ + return pass_through{low_length}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_pad_transform(const LowLength& low_length, + const LeftPad& left_pad, + const RightPad& right_pad, + bool_constant = bool_constant{}) +{ + return pad{low_length, left_pad, right_pad}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_left_pad_transform(const LowLength& low_length, + const LeftPadLength& left_pad_, + bool_constant = bool_constant{}) +{ + return left_pad{low_length, left_pad_}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_right_pad_transform(const LowLength& low_length, + const RightPadLength& right_pad_, + bool_constant = bool_constant{}) +{ + return right_pad{low_length, right_pad_}; +} + +template ::type = false> +CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths& up_lengths, + const Coefficients& coefficients) +{ + return embed{up_lengths, coefficients}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_merge_transform_v2_magic_division(const LowLengths& low_lengths) +{ + return merge_v2_magic_division{low_lengths}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_merge_transform_v3_division_mod(const LowLengths& low_lengths) +{ + return merge_v3_division_mod{low_lengths}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths& low_lengths) +{ + return make_merge_transform_v2_magic_division(low_lengths); +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_unmerge_transform(const UpLengths& up_lengths, + bool_constant = bool_constant{}) +{ + return unmerge{up_lengths}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_freeze_transform(const LowerIndex& low_idx) +{ + return freeze{low_idx}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_insert_transform(const UpperIndex& up_idx) +{ + return insert{up_idx}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_replicate_transform(const UpLengths& up_lengths) +{ + return replicate{up_lengths}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_slice_transform(const LowLength& low_length, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) +{ + return slice{low_length, slice_begin, slice_end}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus, + const UpLength& up_length) +{ + return modulo{modulus, up_length}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths, + const RightShift& right_shift) +{ + 
return xor_t{low_lengths, right_shift}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_length, + const OffsetLength& offset_length) +{ + return offset{low_length, offset_length}; +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/algorithm/space_filling_curve.hpp b/include/ck_tile/core/algorithm/space_filling_curve.hpp new file mode 100644 index 000000000..77a635611 --- /dev/null +++ b/include/ck_tile/core/algorithm/space_filling_curve.hpp @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/statically_indexed_array.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template // # of scalars per access in each dimension +struct space_filling_curve +{ + static constexpr index_t TensorSize = + reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}); + static_assert(0 < TensorSize, + "space_filling_curve should be used to access a non-empty tensor"); + + static constexpr index_t nDim = TensorLengths::size(); + + using Index = multi_index; + + static constexpr index_t ScalarPerVector = + reduce_on_sequence(ScalarsPerAccess{}, multiplies{}, number<1>{}); + + static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{}; + static constexpr auto dim_access_order = DimAccessOrder{}; + static constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(ordered_access_lengths)), + make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}), + make_tuple(sequence<0>{})); + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access() + { + static_assert(TensorLengths::size() == ScalarsPerAccess::size()); + static_assert(TensorLengths{} % ScalarsPerAccess{} == + typename uniform_sequence_gen::type{}); + + return reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}) / ScalarPerVector; + } + + template + static CK_TILE_HOST_DEVICE constexpr auto get_step_between(number, + number) + { + static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(), + "1D index out of range"); + static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(), + "1D index out of range"); + + constexpr auto idx_head = get_index(number{}); + constexpr auto idx_tail = get_index(number{}); + return idx_tail - idx_head; + } + + template + static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number) + { + static_assert(AccessIdx1d < get_num_of_access(), "1D index should be larger than 0"); + return get_step_between(number{}, number{}); + } + + template + static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number) + { + static_assert(AccessIdx1d > 0, "1D index should be larger than 0"); + + return get_step_between(number{}, number{}); + } + + template + static CK_TILE_HOST_DEVICE constexpr Index get_index(number) + { +#if 0 + /* + * \todo: tensor_adaptor::calculate_bottom_index does NOT return constexpr as expected. 
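+         *
+         * The #else branch below therefore decomposes the 1D access index by hand:
+         * it builds reverse-exclusive-scan strides of the ordered access lengths
+         * and peels off one dimension at a time with division and remainder. For
+         * example, ordered access lengths (2, 3, 4) give strides (12, 4, 1), and
+         * a 1D index of 17 decomposes into the multi-index (1, 1, 1).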
+ */ + constexpr auto ordered_access_idx = to_index_adaptor.calculate_bottom_index(make_multi_index(number{})); +#else + + constexpr auto access_strides = + container_reverse_exclusive_scan(ordered_access_lengths, multiplies{}, number<1>{}); + + constexpr auto idx_1d = number{}; + // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the + // idim-th element of multidimensional index. + // All constexpr variables have to be captured by VALUE. + constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr + { + constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr + { + auto res = idx_1d.value; + auto id = 0; + + static_for<0, jdim.value + 1, 1>{}([&](auto kdim) { + id = res / access_strides[kdim].value; + res -= id * access_strides[kdim].value; + }); + + return id; + }; + + constexpr auto id = compute_index_impl(idim); + return number{}; + }; + + constexpr auto ordered_access_idx = generate_tuple(compute_index, number{}); +#endif + constexpr auto forward_sweep = [&]() { + statically_indexed_array forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto idim) { + index_t tmp = ordered_access_idx[I0]; + + static_for<1, idim, 1>{}( + [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); + + forward_sweep_(idim) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate multi-dim tensor index + auto idx_md = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto idim) { + ordered_idx(idim) = + !SnakeCurved || forward_sweep[idim] + ? ordered_access_idx[idim] + : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim]; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + ScalarsPerAccess{}; + }(); + return idx_md; + } + + // FIXME: rename this function + template + static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number) + { + constexpr auto idx = get_index(number{}); + + return generate_tuple([&](auto i) { return number{}; }, number{}); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp new file mode 100644 index 000000000..53f42a742 --- /dev/null +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -0,0 +1,2031 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/vector_type.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/thread_buffer.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/functional.hpp" + +namespace ck_tile { + +// 128 bit SGPRs to supply buffer resource in buffer instructions +// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions +struct __attribute__((packed)) buffer_resource +{ + const void* ptr; + uint32_t range; + uint32_t config; +}; + +CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) +{ + buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD}; + return __builtin_bit_cast(int32x4_t, res); +} + +// TODO: glc/slc/... 
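+// Usage sketch (illustrative only; `p`, `n` and `byte_offset` are made-up names):
+// a wave builds one buffer resource for a global array and each lane then issues
+// a 16-byte vector load through it with the helpers declared below.
+//
+//     CK_TILE_DEVICE void example(const float* p, uint32_t n,
+//                                 index_t byte_offset, fp32x4_t& v)
+//     {
+//         // the resource range is given in bytes
+//         int32x4_t res = make_wave_buffer_resource(p, n * sizeof(float));
+//         // v_offset = per-lane byte offset, s_offset = 0, i_offset = 0
+//         buffer_load<16>{}(v, res, byte_offset, 0, 0);
+//     }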
+template +struct buffer_load; +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wundefined-reinterpret-cast" +// TODO: strict aliasing rule seems fail when reinterpret_cast between vector type +// (exp_vector_type(xxx)) +template <> +struct buffer_load<16> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 16); + using mbuf_t = fp32x4_t; + asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<8> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 8); + using mbuf_t = fp32x2_t; + asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<4> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 4); + using mbuf_t = float; + asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<2> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually + using mbuf_t = float; + asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_load<1> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 0) + { + static_assert(sizeof(T) == 4); + using mbuf_t = float; + asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template +struct buffer_load_if; + +template <> +struct buffer_load_if<16> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 16); + auto saved_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = fp32x4_t; + static_assert(sizeof(mbuf_t) == sizeof(T)); + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; + +template <> +struct buffer_load_if<8> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t 
i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 8); + auto saved_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = fp32x2_t; + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; + +template <> +struct buffer_load_if<4> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 4); + auto saved_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = float; + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_dword %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; + +template <> +struct buffer_load_if<2> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 4); + auto saved_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = float; + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; + +template <> +struct buffer_load_if<1> +{ + template + CK_TILE_DEVICE void operator()(T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 0) + { + static_assert(sizeof(T) == 4); + auto saved_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = float; + asm volatile( + "v_cmpx_le_u32 exec, 1, %5\n" + "buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : "+v"(reinterpret_cast(value)) + : "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset), "v"(flag), "s"(saved_exec) + : "memory"); + } +}; +#pragma clang diagnostic pop // "-Wundefined-reinterpret-cast" +template +struct buffer_store; + +template <> +struct buffer_store<16> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + static_assert(sizeof(T) == 16); + using mbuf_t = fp32x4_t; + asm volatile( + "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_store<8> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + static_assert(sizeof(T) == 8); + using mbuf_t = fp32x2_t; + asm volatile( + "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_store<4> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + 
static_assert(sizeof(T) == 4); + using mbuf_t = float; + asm volatile( + "buffer_store_dword %0, %1, %2, %3 offen offset:%4" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_store<2> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + static_assert(sizeof(T) == 2); + using mbuf_t = short; + asm volatile( + "buffer_store_short %0, %1, %2, %3 offen offset:%4" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template <> +struct buffer_store<1> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t /*flag*/ = 1) + { + static_assert(sizeof(T) == 4); + using mbuf_t = float; + asm volatile( + "buffer_store_byte %0, %1, %2, %3 offen offset:%4" + : + : "v"(bit_cast(value)), "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset) + : "memory"); + } +}; + +template +struct buffer_store_if; + +template <> +struct buffer_store_if<16> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 16); + auto save_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = fp32x4_t; + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_dwordx4 %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(bit_cast(value)), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +template <> +struct buffer_store_if<8> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 8); + auto save_exec = __builtin_amdgcn_read_exec(); + // TODO: ugly. 
rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch + using mbuf_t = ext_vector_t; + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(bit_cast(value)), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +template <> +struct buffer_store_if<4> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 4); + auto save_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = float; + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_dword %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(bit_cast(value)), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +template <> +struct buffer_store_if<2> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 2); + auto save_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = short; + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_short %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(bit_cast(value)), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +template <> +struct buffer_store_if<1> +{ + template + CK_TILE_DEVICE void operator()(const T& value, + int32x4_t res /*buffer resource*/, + index_t v_offset, + index_t s_offset, + index_t i_offset /*max 0xFFF*/, + index_t flag = 1) + { + static_assert(sizeof(T) == 4); + auto save_exec = __builtin_amdgcn_read_exec(); + using mbuf_t = float; + asm volatile("v_cmpx_le_u32 exec, 1, %5\n" + "buffer_store_byte %0, %1, %2, %3 offen offset:%4\n" + "s_mov_b64 exec %6" + : + : "v"(bit_cast(value)), + "v"(v_offset), + "s"(res), + "s"(s_offset), + "n"(i_offset), + "v"(flag), + "s"(save_exec) + : "memory"); + } +}; + +CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + +// clang-format off +namespace impl{ + +// can't use "+v" since there could be potential extra move(read/write) +// use "v" can help remove such duplicated moves +// besides, fake this as "memory" operation to force later valu after this fence +// TODO: may have scratch (because this is memory?) 
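+// (The empty asm statements below pass each dword of the destination buffer as a
+// "v" input with a "memory" clobber, so the compiler has to treat the loaded
+// registers as consumed at this point and keeps dependent VALU work after the
+// preceding s_waitcnt.)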
+// need to reduce extra move inside compiler +template +CK_TILE_DEVICE void insert_dummy_dep_per_dword(array& b) +{ + static_for<0, b.size(), 1>{}([&](auto i){ + asm volatile(" " : : "v"(b.get(i)) : "memory"); + }); +} +#if 1 +// below specialization just merge size() of dwords into single section +template<> +CK_TILE_DEVICE void insert_dummy_dep_per_dword<2>(array& b) +{ + asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})) : "memory"); +} + +template<> +CK_TILE_DEVICE void insert_dummy_dep_per_dword<3>(array& b) +{ + asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})) : "memory"); +} + +template<> +CK_TILE_DEVICE void insert_dummy_dep_per_dword<4>(array& b) +{ + asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})) : "memory"); +} + +template<> +CK_TILE_DEVICE void insert_dummy_dep_per_dword<8>(array& b) +{ + asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})), + "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})) : "memory"); +} + +template<> +CK_TILE_DEVICE void insert_dummy_dep_per_dword<16>(array& b) +{ + asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})), + "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})), + "v"(b.get(number<8>{})), "v"(b.get(number<9>{})), "v"(b.get(number<10>{})), "v"(b.get(number<11>{})), + "v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})) : "memory"); +} + +template<> +CK_TILE_DEVICE void insert_dummy_dep_per_dword<32>(array& b) +{ + asm volatile(" " : : "v"(b.get(number<0>{})), "v"(b.get(number<1>{})), "v"(b.get(number<2>{})), "v"(b.get(number<3>{})), + "v"(b.get(number<4>{})), "v"(b.get(number<5>{})), "v"(b.get(number<6>{})), "v"(b.get(number<7>{})), + "v"(b.get(number<8>{})), "v"(b.get(number<9>{})), "v"(b.get(number<10>{})), "v"(b.get(number<11>{})), + "v"(b.get(number<12>{})), "v"(b.get(number<13>{})), "v"(b.get(number<14>{})), "v"(b.get(number<15>{})), + "v"(b.get(number<16>{})), "v"(b.get(number<17>{})), "v"(b.get(number<18>{})), "v"(b.get(number<19>{})), + "v"(b.get(number<20>{})), "v"(b.get(number<21>{})), "v"(b.get(number<22>{})), "v"(b.get(number<23>{})), + "v"(b.get(number<24>{})), "v"(b.get(number<25>{})), "v"(b.get(number<26>{})), "v"(b.get(number<27>{})), + "v"(b.get(number<28>{})), "v"(b.get(number<29>{})), "v"(b.get(number<30>{})), "v"(b.get(number<31>{})) : "memory"); +} +#endif +CK_TILE_DEVICE void insert_dummy_dep() {} + +template +CK_TILE_DEVICE void insert_dummy_dep(T & buffer) +{ + // TODO: indeed we expect T to be multiple of dword. subdword is always buggy + using da_type = array; + auto & dummy = reinterpret_cast(buffer); + insert_dummy_dep_per_dword(dummy); +} + +template +CK_TILE_DEVICE void insert_dummy_dep(Tx& bx, Ty&... by) +{ + insert_dummy_dep(bx); + insert_dummy_dep(by...); +} +} +// clang-format on +template +CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0, T&... 
o) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); + impl::insert_dummy_dep(o...); +} + +CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + +// buffer load i8 +CK_TILE_DEVICE_EXTERN int8_t +llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8"); + +CK_TILE_DEVICE_EXTERN int8x2_t +llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8"); + +CK_TILE_DEVICE_EXTERN int8x4_t +llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8"); + +// buffer load i16 +CK_TILE_DEVICE_EXTERN int16_t +llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16"); + +CK_TILE_DEVICE_EXTERN int16x2_t +llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16"); + +CK_TILE_DEVICE_EXTERN int16x4_t +llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16"); + +// buffer load i32 +CK_TILE_DEVICE_EXTERN int32_t +llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); + +CK_TILE_DEVICE_EXTERN int32x2_t +llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32"); + +CK_TILE_DEVICE_EXTERN int32x4_t +llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); + +// buffer load fp16 +CK_TILE_DEVICE_EXTERN _Float16 +llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + +CK_TILE_DEVICE_EXTERN fp16x2_t +llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + +CK_TILE_DEVICE_EXTERN fp16x4_t +llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16"); + +// buffer load fp32 +CK_TILE_DEVICE_EXTERN float +llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + +CK_TILE_DEVICE_EXTERN fp32x2_t +llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32"); + +CK_TILE_DEVICE_EXTERN fp32x4_t +llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); + +// buffer store i8 +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8"); + +CK_TILE_DEVICE_EXTERN 
void +llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8"); + +// buffer store i16 +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i16(int16_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i16x2(int16x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i16x4(int16x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16"); + +// buffer store i32 +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); + +// buffer store fp16 +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_fp16(_Float16 vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_fp16x2(fp16x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_fp16x4(fp16x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16"); + +// buffer store fp32 +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_fp32(float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_fp32x2(fp32x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_store_fp32x4(fp32x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); + +// buffer atomic-add fp16 +CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( + fp16x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16"); + +// buffer atomic-add i32 +CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( + int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32"); + +// buffer atomic-add fp32 +CK_TILE_DEVICE_EXTERN float llvm_amdgcn_raw_buffer_atomic_add_fp32( + float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32"); + +// buffer atomic-max fp64 +CK_TILE_DEVICE_EXTERN double +llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, + int32x4_t rsrc, // dst_wave_buffer_resource 
+ int voffset, // dst_thread_addr_offset + int soffset, // dst_wave_addr_offset + int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); + +CK_TILE_DEVICE void async_buffer_load_dword(void* smem, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t ioffset /*max 0xFFF*/, + index_t /*flag*/ = 0) +{ + asm volatile("buffer_load_dword %1, %2, %3 offen offset:%4 lds" + : "=r"(smem) /*dummy dependency for smem*/ + : "v"(voffset), "s"(rsrc), "s"(soffset), "n"(ioffset) + : "memory"); +} + +CK_TILE_DEVICE void async_buffer_load_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + +// memory coherency bit for buffer store/load instruction +// check ISA manual for each GFX target +// e.g. for +// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf, +// page 67~68 +enum struct amd_buffer_coherence_enum +{ + coherence_default = 0, // default value + glc = 1, + slc = 2, + glc_slc = 3, +}; + +template +CK_TILE_DEVICE thread_buffer +amd_buffer_load_impl_with_bytes(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + "wrong! not implemented"); + + using rtn_type = thread_buffer; + + if constexpr(N == 1) + { + return bit_cast(llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); + } + else if constexpr(N == 2) + { + + int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + else if constexpr(N == 4) + { + int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + else if constexpr(N == 8) + { + int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + else if constexpr(N == 16) + { + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + return bit_cast(tmp); + } + else if constexpr(N == 32) + { + int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + int32x4_t tmp1 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + static_cast(coherence)); + thread_buffer tmp; + + tmp.template get_as()(number<0>{}) = tmp0; + tmp.template get_as()(number<1>{}) = tmp1; + + return bit_cast(tmp); + } + else if constexpr(N == 64) + { + int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + int32x4_t tmp1 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + static_cast(coherence)); + int32x4_t tmp2 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(int32_t), + static_cast(coherence)); + int32x4_t tmp3 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + 
src_wave_addr_offset + 12 * sizeof(int32_t), + static_cast(coherence)); + + thread_buffer tmp; + + tmp.template get_as()(number<0>{}) = tmp0; + tmp.template get_as()(number<1>{}) = tmp1; + tmp.template get_as()(number<2>{}) = tmp2; + tmp.template get_as()(number<3>{}) = tmp3; + + return bit_cast(tmp); + } +} + +#ifndef BUFFER_LOAD_USE_INLINEASM +#define BUFFER_LOAD_USE_INLINEASM 0 +#endif + +template +CK_TILE_DEVICE thread_buffer amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert( + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! not implemented"); + + using rtn_type = thread_buffer; + + if constexpr(std::is_same::value) // fp32 + { + if constexpr(N == 1) + { + return bit_cast( + llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); + } + else if constexpr(N == 2) + { + return bit_cast( + llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); + } + else if constexpr(N == 4) + { + return bit_cast( + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); + } + else if constexpr(N == 8) + { + thread_buffer tmp; + + tmp.template get_as()(number<0>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + tmp.template get_as()(number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + static_cast(coherence)); + + return tmp; + } + else if constexpr(N == 16) + { + thread_buffer tmp; + + tmp.template get_as()(number<0>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + tmp.template get_as()(number<1>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + static_cast(coherence)); + + tmp.template get_as()(number<2>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(float), + static_cast(coherence)); + + tmp.template get_as()(number<3>{}) = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(float), + static_cast(coherence)); + + return tmp; + } + } + else if constexpr(std::is_same::value) // fp16 + { + if constexpr(N == 1) + { + return bit_cast( + llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); + } + else if constexpr(N == 2) + { + return bit_cast( + 
llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); + } + else if constexpr(N == 4) + { + return bit_cast( + llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); + } + else if constexpr(N == 8) + { + // use fp32 load to mimic fp16 load + fp32x4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + } + else if constexpr(std::is_same::value) // bf16 + { + if constexpr(N == 1) + { + return bit_cast( + llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); + } + else if constexpr(N == 2) + { + return bit_cast( + llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); + } + else if constexpr(N == 4) + { + return bit_cast( + llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence))); + } + else if constexpr(N == 8) + { + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + } + else // other datatype + { + auto raw_data = amd_buffer_load_impl_with_bytes( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset); + + return bit_cast(raw_data); + } +} + +template +CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer& dst, + int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset, + index_t flag = 0) +{ + constexpr index_t bytes = sizeof(T) * N; + static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, + "wrong! not supported by buffer_load instruction"); + + using type = thread_buffer; + if constexpr(oob_conditional_check) + { + buffer_load_if{}( + dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + } + else + { + buffer_load{}( + dst, src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0, flag); + } +} + +template +CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem, + int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset, + index_t src_immediate_addr_offset = 0) +{ + static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size"); + + async_buffer_load_dword(smem, + src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + src_immediate_addr_offset); +} + +template +CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + "wrong! 
not implemented"); + + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i8(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + + llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + llvm_amdgcn_raw_buffer_store_i32x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 16) + { + llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 32) + { + llvm_amdgcn_raw_buffer_store_i32x4( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); + } + else if constexpr(N == 64) + { + llvm_amdgcn_raw_buffer_store_i32x4( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4( + src_thread_data.template get_as()[number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 8, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4( + src_thread_data.template get_as()[number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 12, + static_cast(coherence)); + } +} + +template +CK_TILE_DEVICE void amd_buffer_store_impl(const thread_buffer src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert( + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! 
not implemented"); + + if constexpr(std::is_same::value) // fp32 + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp32(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp32x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + llvm_amdgcn_raw_buffer_store_fp32x4( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + llvm_amdgcn_raw_buffer_store_fp32x4( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(float), + static_cast(coherence)); + } + } + else if constexpr(std::is_same::value) // fp16 + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp16(bit_cast<_Float16>(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp16x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_fp16x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { +#if 0 + thread_buffer tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(fp16_t), + static_cast(coherence)); +#else + llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); +#endif + } + } + else if constexpr(std::is_same::value) // bf16 + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_i16x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_i16x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + llvm_amdgcn_raw_buffer_store_i16x4( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i16x4( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(bf16_t), + 
static_cast(coherence)); + } + } + else + { + using r_t = thread_buffer; + + amd_buffer_store_impl_with_bytes(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset); + } +} + +template +CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer& dst_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset, + index_t is_valid_element = 1) +{ + constexpr index_t bytes = sizeof(T) * N; + static_assert(bytes == 1 || bytes == 2 || bytes == 4 || bytes == 8 || bytes == 16, + "wrong! not supported by buffer_store instruction"); + + using type = thread_buffer; + if constexpr(oob_conditional_check) + { + buffer_store_if{}(dst_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0, + is_valid_element); + } + else + { + buffer_store{}(dst_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } +} + +template +CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer& src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((std::is_same::value && (N == 1 || N == 2 || N == 4)) || + (std::is_same::value && (N == 2 || N == 4 || N == 8)) || + (std::is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! not implemented"); + + if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_fp32(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_atomic_add_fp32( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_atomic_add_fp32( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32( + src_thread_data.template get_as()[number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32( + src_thread_data.template get_as()[number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(float), + 0); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + static_for<0, 2, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2( + src_thread_data.template get_as()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(fp16x2_t), + 0); + }); + } + else if constexpr(N == 8) + { + static_for<0, 4, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2( + src_thread_data.template get_as()[i], + dst_wave_buffer_resource, + 
dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(fp16x2_t), + 0); + }); + } + } + else if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_i32(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_atomic_add_i32( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_atomic_add_i32( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32( + src_thread_data.template get_as()[number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32( + src_thread_data.template get_as()[number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(int32_t), + 0); + } + } +} + +template +CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const thread_buffer src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((std::is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! not implemented"); + if constexpr(std::is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_max_fp64(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_atomic_max_fp64( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_atomic_max_fp64( + src_thread_data.template get_as()[number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64( + src_thread_data.template get_as()[number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64( + src_thread_data.template get_as()[number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64( + src_thread_data.template get_as()[number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(double), + 0); + } + } +} + +// buffer_load requires: +// 1) p_src_wave must point to global memory space +// 2) p_src_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. 
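+//
+// How the OOB offset trick used below works (a sketch of the intent, not normative): when
+// CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK is enabled, an invalid
+// element adds 0x80000000 to the thread byte offset. make_wave_buffer_resource() sets
+// num_records to the buffer size in bytes, so for any buffer smaller than 2 GB the
+// shifted offset is out of range for the raw buffer instruction; the hardware then
+// returns zero for loads (and drops stores), which lets the wrapper skip an explicit
+// validity branch.
+// Minimal usage sketch (hypothetical names p_wave_base / element_space_size; every lane
+// of the wave passes the same wavewise base pointer and element count):
+//   bool valid = thread_element_offset + 4 <= element_space_size;
+//   auto v = amd_buffer_load_invalid_element_return_zero<float, 4>(
+//       p_wave_base, thread_element_offset, valid, element_space_size);
+// An invalid lane receives a zero-filled thread_buffer<float, 4> instead of faulting.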
+// oob_conditional_check : dynamic check if out-of-bound +template +CK_TILE_DEVICE thread_buffer +amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + +#if CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK + uint32_t src_addr_shift = [&]() { + if constexpr(oob_conditional_check) + return src_thread_element_valid ? 0 : 0x80000000; + else + return 0; + }(); + return amd_buffer_load_impl( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); +#else + thread_buffer tmp = + amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); + if constexpr(oob_conditional_check) + return src_thread_element_valid ? tmp : thread_buffer{numeric::zero()}; + else + return tmp; +#endif +} + +// buffer_load requires: +// 1) p_src_wave must point to global memory space +// 2) p_src_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +CK_TILE_DEVICE thread_buffer +amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size, + T customized_value) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + thread_buffer tmp = + amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); + + if constexpr(oob_conditional_check) + return src_thread_element_valid ? tmp : thread_buffer{customized_value}; + else + return tmp; +} + +template +CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer& dst, + const T* p_src_wave, + index_t src_thread_element_offset, + index_t src_element_space_size, + index_t is_valid_element = 0) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_buffer_load_raw_impl( + dst, src_wave_buffer_resource, src_thread_addr_offset, 0, is_valid_element); +} + +// unfortunately async copy can not make sure invalid data is zero inside LDS +// ... unless people manually write zero to LDS at the proper address. +// so not support invalid_element check for now. +// buffer_load OOB still working. +template +CK_TILE_DEVICE void amd_async_buffer_load_with_oob(T* smem, + const T* p_src_wave, + index_t src_thread_element_offset, + index_t src_element_space_size) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size * sizeof(T)); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + amd_async_buffer_load_impl( + smem, src_wave_buffer_resource, src_thread_addr_offset, 0, 0); +} + +// buffer_store requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. 
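+//
+// Minimal usage sketch for the store path (hypothetical names, mirroring the load
+// example above): depending on CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK,
+// an invalid element is either redirected past num_records or skipped by a branch.
+//   thread_buffer<float, 4> data = /* per-thread results */;
+//   amd_buffer_store<float, 4>(data,
+//                              p_wave_base,
+//                              thread_element_offset,
+//                              /*dst_thread_element_valid=*/valid,
+//                              element_space_size);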
+template +CK_TILE_DEVICE void amd_buffer_store(const thread_buffer& src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + +#if CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = [&]() { + if constexpr(oob_conditional_check) + return dst_thread_element_valid ? 0 : 0x80000000; + else + return 0; + }(); + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if constexpr(oob_conditional_check) + { + if(dst_thread_element_valid) + { + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } + } + else + { + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +template +CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer& src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + amd_buffer_store_raw_impl(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + 0, + dst_thread_element_valid); +} + +// buffer_atomic_add requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer& src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + +#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; + + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +// buffer_atomic_max requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer& src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T)); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + +#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 
0 : 0x80000000; + + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +// Direct loads from global to LDS. +CK_TILE_DEVICE_EXTERN void +llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, + __attribute__((address_space(3))) uint32_t* lds_ptr, + index_t size, + index_t voffset, + index_t soffset, + index_t offset, + index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); + +template +CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, + const index_t global_offset, + T* lds_base_ptr, + const index_t lds_offset, + const bool is_valid, + const index_t src_element_space_size) +{ + // Direct loads require that each thread reads and writes exactly a single DWORD. + constexpr auto dword_bytes = 4; + constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; + static_assert(bytes_per_thread == dword_bytes); + + const uint32_t* global_ptr = + reinterpret_cast(reinterpret_cast(global_base_ptr)); + const int32x4_t src_resource = + make_wave_buffer_resource(global_ptr, src_element_space_size * sizeof(T)); + const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; + +#if CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM + T* lds_ptr = lds_base_ptr + lds_offset; + auto const lds_ptr_sgpr = + __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), + "v"(global_offset_bytes), + "s"(src_resource)); +#else + // LDS pointer must be attributed with the LDS address space. + __attribute__((address_space(3))) uint32_t* lds_ptr = + reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( + reinterpret_cast(lds_base_ptr + lds_offset)); + + llvm_amdgcn_raw_buffer_load_lds( + src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); +#endif +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp new file mode 100644 index 000000000..888f0e728 --- /dev/null +++ b/include/ck_tile/core/arch/arch.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +// Address Space for AMDGCN +// https://llvm.org/docs/AMDGPUUsage.html#address-space + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" + +namespace ck_tile { + +enum struct address_space_enum +{ + generic, + global, + lds, + sgpr, + vgpr, +}; + +enum struct memory_operation_enum +{ + set, + atomic_add, + atomic_max, + add +}; + +CK_TILE_HOST_DEVICE constexpr index_t get_warp_size() +{ + // warpSize is defined by HIP + return warpSize; +} + +CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; } + +CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; } + +// TODO: deprecate these +CK_TILE_DEVICE index_t get_thread_local_1d_id() { return threadIdx.x; } + +CK_TILE_DEVICE index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; } + +CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; } + +// Use these instead +CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); } + +CK_TILE_DEVICE index_t get_warp_id() +{ + return __builtin_amdgcn_readfirstlane(threadIdx.x / get_warp_size()); +} + +CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; } + +CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; } + +CK_TILE_DEVICE void block_sync_lds() +{ +#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM + asm volatile("\ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); +#else + __syncthreads(); +#endif +} + +CK_TILE_DEVICE void block_sync_lds_direct_load() +{ + asm volatile("\ + s_waitcnt vmcnt(0) \n \ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); +} + +CK_TILE_DEVICE void s_nop() +{ +#if 1 + asm volatile("\ + s_nop 0 \n \ + " ::); +#else + __builtin_amdgcn_sched_barrier(0); +#endif +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/arch/utility.hpp b/include/ck_tile/core/arch/utility.hpp new file mode 100644 index 000000000..42508e66a --- /dev/null +++ b/include/ck_tile/core/arch/utility.hpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +// Address Space for AMDGCN +// https://llvm.org/docs/AMDGPUUsage.html#address-space + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" + +#include + +namespace ck_tile { + +// TODO: we have "memory" clobber here because this inline asm is used for async copy +CK_TILE_DEVICE void m0_set_with_memory(index_t v) +{ + asm volatile("s_mov_b32 m0, %0" : : "s"(v) : "memory"); +} + +// NOTE: this is an immediate value +CK_TILE_DEVICE void m0_inc_with_memory(index_t v) +{ + asm volatile("s_add_u32 m0, %0, m0" : : "n"(v) : "memory"); +} + +template +CK_TILE_DEVICE T warp_shuffle_up(const T& v_local, uint32_t lane_delta) +{ +#if 0 + return __shfl_up(v_local, lane_delta); +#elif 1 + static_assert(sizeof(T) == sizeof(int32_t), "wrong!"); + + const uint32_t wrap_around_lane_delta = warpSize - lane_delta; + + const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute( + (__lane_id() << 2) + (wrap_around_lane_delta << 2), bit_cast(v_local)); + + return bit_cast(v_remote_tmp); +#endif +} + +template +CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta) +{ +#if 0 + return __shfl_down(v_local, lane_delta); +#elif 1 + static_assert(sizeof(T) == sizeof(int32_t), "wrong!"); + + const int32_t v_remote_tmp = __builtin_amdgcn_ds_bpermute( + (__lane_id() << 2) + (lane_delta << 2), bit_cast(v_local)); + + return bit_cast(v_remote_tmp); +#endif +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp new file mode 100644 index 000000000..d915df6e4 --- /dev/null +++ b/include/ck_tile/core/config.hpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS +#include "hip/hip_runtime.h" +#include "hip/hip_fp16.h" +#endif + +#ifdef __HIPCC__ +#define CK_TILE_HOST inline __host__ +#define CK_TILE_DEVICE inline __device__ +#define CK_TILE_HOST_DEVICE inline __host__ __device__ +#define CK_TILE_DEVICE_EXTERN __device__ +#else +#define CK_TILE_HOST inline +#define CK_TILE_DEVICE inline +#define CK_TILE_HOST_DEVICE inline +#define CK_TILE_DEVICE_EXTERN +#endif + +#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE +#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code +#endif + +#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0 +#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1 +#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2 + +#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT +#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE +#endif + +#define CK_TILE_FLOAT_TO_FP8_STANDARD 0 +#define CK_TILE_FLOAT_TO_FP8_STOCHASTIC 1 + +#ifndef CK_TILE_FLOAT_TO_FP8_DEFAULT +#define CK_TILE_FLOAT_TO_FP8_DEFAULT CK_TILE_FLOAT_TO_FP8_STANDARD +#endif + +// in the old rocm period, we have to use tuple array implementation to implement this +// so turn on the _USE_TUPLE if meet compiler error, otherwise _USE_ARRAY by default. 
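+//
+// Because the *_DEFAULT selectors below are guarded by #ifndef, a build can pre-select
+// either backend from the compiler command line, e.g. (illustrative flags; 0 selects the
+// array backend and 1 the tuple backend, matching the *_USE_ARRAY / *_USE_TUPLE values
+// defined below):
+//   -DCK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT=0
+//   -DCK_TILE_THREAD_BUFFER_DEFAULT=1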
+#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY 0 +#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE 1 +#ifndef CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT +#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE +#endif + +#define CK_TILE_THREAD_BUFFER_USE_ARRAY 0 +#define CK_TILE_THREAD_BUFFER_USE_TUPLE 1 +#ifndef CK_TILE_THREAD_BUFFER_DEFAULT +#define CK_TILE_THREAD_BUFFER_DEFAULT CK_TILE_THREAD_BUFFER_USE_ARRAY +#endif + +#ifndef CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST +#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE +// if using tuple-array as thread_buffer implementation, need to support {} brace init +// ... with similiar behavior as array +#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 1 +#else +#define CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST 0 +#endif +#endif + +#ifndef CK_TILE_USE_LAUNCH_BOUNDS +#define CK_TILE_USE_LAUNCH_BOUNDS 1 +#endif + +#ifndef CK_TILE_TIME_KERNEL +#define CK_TILE_TIME_KERNEL 1 +#endif + +#define CK_TILE_MAX_THREAD_PER_BLOCK 256 +#define CK_TILE_MIN_BLOCK_PER_CU 2 + +#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK +#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 +#endif + +#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK +#define CK_TILE_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 +#endif + +#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK +#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 +#endif + +#ifndef CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK +#define CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1 +#endif + +#ifndef CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM +#define CK_TILE_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1 +#endif + +#ifndef CK_TILE_USE_AMD_BUFFER_LOAD +#define CK_TILE_USE_AMD_BUFFER_LOAD 1 +#endif + +#ifndef CK_TILE_USE_AMD_BUFFER_STORE +#define CK_TILE_USE_AMD_BUFFER_STORE 1 +#endif + +#ifndef CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER +#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1 +#endif + +// buffer atomic add: floating point +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 +#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) // for GPU code +#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 +#else // for GPU code +#define CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 +#endif + +#if(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__)) // for GPU code +#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 +#else +#define CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 +#endif + +#ifndef CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS +#define CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0 +#endif + +#ifndef CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE +#define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 +#endif + +#ifndef CK_TILE_DEBUG_LOG +#define CK_TILE_DEBUG_LOG 0 +#endif + +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0xffffffff +#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \ + defined(__gfx942__) // for GPU code +#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000 +#elif defined(__gfx1030__) // for GPU code +#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000 
+#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code +#define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000 +#endif + +#ifndef CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM +#define CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 +#endif + +#ifndef CK_TILE_USE_SUBDWORD_TILE_CAST +#define CK_TILE_USE_SUBDWORD_TILE_CAST 0 +#endif diff --git a/include/ck_tile/core/container/array.hpp b/include/ck_tile/core/container/array.hpp new file mode 100644 index 000000000..c272b01f5 --- /dev/null +++ b/include/ck_tile/core/container/array.hpp @@ -0,0 +1,251 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/utility/functional.hpp" + +namespace ck_tile { + +// use aggregate initialization for this type +// e.g. array buf {0}; => {0, 0, 0, 0}, clean +// array buf {3, 2}; => {3, 2, 2, 2} (not {3,2,0,0}) +// use make_array_with({...}) to construct an array with compatible behavior as old ck +// TODO: manually added constructor same as old ck +template +struct array +{ + using value_type = T_; + static constexpr index_t N = N_; + // TODO: do we need this? + // using bulk_type = uint8_t __attribute__((ext_vector_type(N * sizeof(value_type)))); + // union { + value_type data[N]; + // bulk_type __content; + //}; + CK_TILE_HOST_DEVICE constexpr array() : data{} {} + // TODO: will initialize the data[] with the last value repeatedly + // behavior different from std + CK_TILE_HOST_DEVICE constexpr array(std::initializer_list ilist) + { + constexpr index_t list_size = std::initializer_list{}.size(); + static_assert(list_size <= N, "out of bound"); + + index_t i = 0; + value_type vlast = value_type{}; + + for(const value_type& val : ilist) + { + data[i] = val; + vlast = val; + ++i; + } + for(; i < N; ++i) + { + data[i] = vlast; + } + } + + template || + std::is_constructible_v>> + CK_TILE_HOST_DEVICE explicit constexpr array(Y c) + { + for(auto i = 0; i < size(); i++) + data[i] = static_cast(c); + } + + // template + // CK_TILE_HOST_DEVICE constexpr array(const array& o) + // { + // // static_assert(ArrayType::size() == size(), "wrong! size not the same"); + // __content = o.__content; + // } + // CK_TILE_HOST_DEVICE constexpr array& operator=(const array& o) + // { + // // static_assert(ArrayType::size() == size(), "wrong! 
size not the same"); + // __content = o.__content; + // return *this; + // } + + CK_TILE_HOST_DEVICE static constexpr auto size() { return N; } + CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v; } + + // clang-format off + CK_TILE_HOST_DEVICE constexpr auto& get() { return data; } + CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data; } + CK_TILE_HOST_DEVICE constexpr auto& get(index_t i) { return data[i]; } + CK_TILE_HOST_DEVICE constexpr const auto& get(index_t i) const { return data[i]; } + template CK_TILE_HOST_DEVICE constexpr auto& get() { return data[I]; } + template CK_TILE_HOST_DEVICE constexpr const auto& get() const { return data[I]; } + template CK_TILE_HOST_DEVICE constexpr auto& get(number) { return data[I]; } + template CK_TILE_HOST_DEVICE constexpr const auto& get(number) const { return data[I]; } + + CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); } + CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); } + template CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); } + template CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); } + template CK_TILE_HOST_DEVICE constexpr auto& at(number) { return get(I); } + template CK_TILE_HOST_DEVICE constexpr const auto& at(number) const { return get(I); } + + CK_TILE_HOST_DEVICE constexpr const value_type& operator[](index_t i) const { return get(i); } + CK_TILE_HOST_DEVICE constexpr value_type& operator[](index_t i) { return get(i); } + CK_TILE_HOST_DEVICE constexpr value_type& operator()(index_t i) { return get(i); } // TODO: compatible +#if 0 + template + CK_TILE_HOST_DEVICE constexpr auto operator=(const ArrayLike& arr) + { + static_assert(ArrayLike::size() == size(), "wrong! size not the same"); + for(index_t i = 0; i < size(); ++i) + { + data[i] = arr[i]; + } + return *this; + } +#endif + // type punning (strict aliasing) member functions for read/write + // aliasing this array of type "T", "N" elements + // as array of type "Tx", sizeof(T)*N/sizeof(Tx) elements +#define AR_AS_COM_() \ + static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \ + constexpr int vx = sizeof(value_type) * N / sizeof(Tx) + + template CK_TILE_HOST_DEVICE constexpr auto& get_as() + { AR_AS_COM_(); return reinterpret_cast&>(data); } + template CK_TILE_HOST_DEVICE constexpr const auto& get_as() const + { AR_AS_COM_(); return reinterpret_cast&>(data); } + + // below index is for index *AFTER* type convert, not before + template CK_TILE_HOST_DEVICE constexpr auto& get_as(index_t i) + { AR_AS_COM_(); return reinterpret_cast&>(data).at(i); } + template CK_TILE_HOST_DEVICE constexpr const auto& get_as(index_t i) const + { AR_AS_COM_(); return reinterpret_cast&>(data).at(i); } + template CK_TILE_HOST_DEVICE constexpr auto& get_as(number) + { AR_AS_COM_(); return reinterpret_cast&>(data).at(number{}); } + template CK_TILE_HOST_DEVICE constexpr const auto& get_as(number) const + { AR_AS_COM_(); return reinterpret_cast&>(data).at(number{}); } + + template CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) + { AR_AS_COM_(); reinterpret_cast&>(data).at(i) = x; } + template CK_TILE_HOST_DEVICE constexpr void set_as(number, const Tx & x) + { AR_AS_COM_(); reinterpret_cast&>(data).at(number{}) = x; } +#undef AR_AS_COM_ + // clang-format on +}; + +// empty Array + +template +struct array +{ + using value_type = T; + + CK_TILE_HOST_DEVICE constexpr array() {} + CK_TILE_HOST_DEVICE static constexpr index_t size() { return 0; } + 
CK_TILE_HOST_DEVICE static constexpr bool is_static() { return is_static_v; }; + CK_TILE_HOST_DEVICE void print() const { printf("array{size: 0, data: []}"); } +}; + +template +struct vector_traits; + +// specialization for array +template +struct vector_traits> +{ + using scalar_type = T; + static constexpr index_t vector_size = N; +}; + +namespace details { +template +struct is_ref_wrapper : std::false_type +{ +}; +template +struct is_ref_wrapper> : std::true_type +{ +}; + +template +using not_ref_wrapper = std::negation>>; + +template +struct return_type_helper +{ + using type = D; +}; +template +struct return_type_helper : std::common_type +{ + static_assert(std::conjunction_v...>, + "Ts cannot contain reference_wrappers when D is void"); +}; + +template +using return_type = array::type, sizeof...(Ts)>; +} // namespace details + +template +CK_TILE_HOST_DEVICE constexpr details::return_type make_array(Ts&&... ts) +{ + return {std::forward(ts)...}; +} + +// // make empty array +// template +// CK_TILE_HOST_DEVICE constexpr auto make_array() +// { +// return array{}; +// } + +// compatible with old ck's initializer, make an array and fill it withe the last element from +// initializer_list +template +CK_TILE_HOST_DEVICE constexpr auto make_array_with(std::initializer_list ilist) +{ + return array(ilist); +} + +template +CK_TILE_HOST_DEVICE constexpr bool operator==(const array& a, const array& b) +{ + bool same = true; + + for(index_t i = 0; i < Size; ++i) + { + if(a[i] != b[i]) + { + same = false; + break; + } + } + + return same; +} + +template +CK_TILE_HOST_DEVICE constexpr bool operator!=(const array& a, const array& b) +{ + return !(a == b); +} + +template +CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x) +{ + static_assert(N <= X::size(), ""); + + array arr; + + static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; }); + + return arr; +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/container_helper.hpp b/include/ck_tile/core/container/container_helper.hpp new file mode 100644 index 000000000..474eda80d --- /dev/null +++ b/include/ck_tile/core/container/container_helper.hpp @@ -0,0 +1,499 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/map.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/utility/functional.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array& a, const TData& x) +{ + array r; + static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; }); + r[number{}] = x; + return r; +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_push_front(const tuple& a, const T& x) +{ + return container_concat(make_tuple(x), a); +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_push_back(const tuple& a, const T& x) +{ + return container_concat(a, make_tuple(x)); +} + +// reorder array +template +CK_TILE_HOST_DEVICE constexpr auto +container_reorder_given_new2old(const array& old_array, sequence /*new2old*/) +{ + static_assert(NSize == sizeof...(IRs), "wrong! size not consistent"); + static_assert(is_valid_sequence_map>{}, "wrong! 
invalid reorder map"); + return make_array>(old_array[IRs]...); +} + +template +CK_TILE_HOST_DEVICE constexpr auto +container_reorder_given_old2new(const array& old_array, sequence old2new) +{ + return container_reorder_given_new2old( + old_array, typename sequence_map_inverse::type{}); +} + +// reorder array +template +CK_TILE_HOST_DEVICE constexpr auto +container_reorder_given_new2old(const array& old_array, + const map& new2old) +{ + array new_array; + + for(const auto& [new_pos, old_pos] : new2old) + { + new_array(new_pos) = old_array[old_pos]; + } + + return new_array; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +container_reorder_given_old2new(const array& old_array, + const map& old2new) +{ + array new_array; + + for(const auto& [old_pos, new_pos] : old2new) + { + new_array(new_pos) = old_array[old_pos]; + } + + return new_array; +} + +// reorder tuple +template +CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(const tuple& old_tuple, + sequence /*new2old*/) +{ + static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return make_tuple(old_tuple[number{}]...); +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(const tuple& old_tuple, + sequence old2new) +{ + return container_reorder_given_new2old( + old_tuple, typename sequence_map_inverse::type{}); +} + +// reorder sequence +template +CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_new2old(sequence /* old_seq */, + sequence /*new2old*/) +{ + static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return sequence::at(number{})...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_reorder_given_old2new(sequence old_seq, + sequence /* old2new */) +{ + static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! 
invalid reorder map"); + + constexpr auto new2old = typename sequence_map_inverse>::type{}; + + return container_reorder_given_new2old(old_seq, new2old); +} + +#if 0 +// rocm-4.1 compiler would crash for recursive lambda +template +CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x, + Reduce reduce, + Init init, + number = number<0>{}, + number = number{}, + number = number<1>{}) +{ + static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); + + // f is recursive function, fs is a dummy of f + // i is index, y_old is current scan, r_old is current reduction + auto f = [&](auto fs, auto i, auto r_old) { + auto r_new = reduce(x[i], r_old); + + if constexpr(i.value < IEnd - IStep) + { + // recursively call f/fs + return fs(fs, i + number{}, r_new); + } + else + { + return r_new; + } + }; + + // start recursion + return f(f, number{}, init); +} +#else +// i is index, y_old is current scan, r_old is current reduction +template +CK_TILE_HOST_DEVICE constexpr auto container_reduce_impl( + const Container& x, Reduce reduce, ROld r_old, number i, number, number) +{ + auto r_new = reduce(x[i], r_old); + + if constexpr(i.value < IEnd - IStep) + { + return container_reduce_impl( + x, reduce, r_new, i + number{}, number{}, number{}); + } + else + { + return r_new; + } +} + +// rocm-4.1 compiler would crash for recursive lambda +// container reduce with initial value +template +CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container& x, + Reduce reduce, + Init init, + number = number<0>{}, + number = number{}, + number = number<1>{}) +{ + static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); + + if constexpr(IEnd > IBegin) + { + return container_reduce_impl( + x, reduce, init, number{}, number{}, number{}); + } + else + { + return init; + } +} +#endif + +template +CK_TILE_HOST_DEVICE constexpr auto +container_reverse_inclusive_scan(const array& x, Reduce f, TData init) +{ + array y; + + TData r = init; + + static_for{}([&](auto i) { + r = f(r, x[i]); + y(i) = r; + }); + + r = f(r, x[number<0>{}]); + y(number<0>{}) = r; + + return y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +container_reverse_exclusive_scan(const array& x, Reduce f, Init init) +{ +#if 0 + array y; + + TData r = init; + + static_for{}([&](auto i) { + y(i) = r; + r = f(r, x[i]); + }); + + y(number<0>{}) = r; + + return y; +#else + array y; + + TData r = init; + + for(index_t i = NSize - 1; i > 0; --i) + { + y(i) = r; + r = f(r, x[i]); + } + + y(0) = r; + + return y; +#endif +} + +template +CK_TILE_HOST_DEVICE constexpr auto +container_reverse_exclusive_scan(const sequence& seq, Reduce f, number) +{ + return reverse_exclusive_scan_sequence(seq, f, number{}); +} + +#if 0 +// rocm4.1 compiler would crash with recursive lambda +template +CK_TILE_HOST_DEVICE constexpr auto +container_reverse_exclusive_scan(const tuple& x, Reduce reduce, Init init) +{ + constexpr index_t NSize = sizeof...(Xs); + + // f is recursive function, fs is a dummy of f + // i is index, y_old is current scan, r_old is current reduction + auto f = [&](auto fs, auto i, auto y_old, auto r_old) { + auto r_new = reduce(x[i], r_old); + + auto y_new = container_push_front(y_old, r_new); + + if constexpr(i.value > 1) + { + // recursively call f/fs + return fs(fs, i - number<1>{}, y_new, r_new); + } + else + { + return y_new; + } + }; + + // start recursion + return f(f, number{}, make_tuple(init), init); +} +#else +// i is index, y_old is current scan, r_old is current reduction +template +CK_TILE_HOST_DEVICE constexpr auto 
container_reverse_exclusive_scan_impl( + const tuple& x, Reduce reduce, number i, YOld y_old, ROld r_old) +{ + auto r_new = reduce(x[i], r_old); + + auto y_new = container_push_front(y_old, r_new); + + if constexpr(i.value > 1) + { + // recursively call f/fs + return container_reverse_exclusive_scan_impl(x, reduce, i - number<1>{}, y_new, r_new); + } + else + { + return y_new; + } +} + +template +CK_TILE_HOST_DEVICE constexpr auto +container_reverse_exclusive_scan(const tuple& x, Reduce reduce, Init init) +{ + constexpr index_t NSize = sizeof...(Xs); + + return container_reverse_exclusive_scan_impl( + x, reduce, number{}, make_tuple(init), init); +} +#endif + +// TODO: update to like container_reverse_exclusive_scan to deal with tuple of Numebr<> +template +CK_TILE_HOST_DEVICE constexpr auto +container_reverse_inclusive_scan(const tuple& x, Reduce f, TData init) +{ + constexpr index_t NSize = sizeof...(Xs); + + tuple y; + + TData r = init; + + static_for{}([&](auto i) { + r = f(r, x[i]); + y(i) = r; + }); + + r = f(r, x[number<0>{}]); + y(number<0>{}) = r; + + return y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_concat(const X& x, const Ys&... ys) +{ + return container_concat(x, container_concat(ys...)); +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_concat(const array& ax, const array& ay) +{ + return unpack2( + [&](auto&&... zs) { return make_array(std::forward(zs)...); }, ax, ay); +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_concat(const tuple& tx, const tuple& ty) +{ + return unpack2( + [&](auto&&... zs) { return make_tuple(std::forward(zs)...); }, tx, ty); +} + +template +CK_TILE_HOST_DEVICE constexpr auto container_concat(const Container& x) +{ + return x; +} + +template +CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const array& arr, sequence) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + if constexpr(sizeof...(Is) > 0) + { + return make_array(arr[Is]...); + } + else + { + return array{}; + } +} + +template +CK_TILE_HOST_DEVICE constexpr auto get_container_subset(const tuple& tup, sequence) +{ + static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size"); + + if constexpr(sizeof...(Is) > 0) + { + return make_tuple(tup[number{}]...); + } + else + { + return tuple<>{}; + } +} + +template +CK_TILE_HOST_DEVICE constexpr void +set_container_subset(array& y, sequence picks, const array& x) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + if constexpr(sizeof...(Is) > 0) + { + for(index_t i = 0; i < picks.size(); ++i) + { + y(picks[i]) = x[i]; + } + } +} + +template +CK_TILE_HOST_DEVICE constexpr void set_container_subset(Y& y, sequence picks, const X& x) +{ + static_assert(Y::size() >= sizeof...(Is) && X::size() == sizeof...(Is), "wrong! size"); + + if constexpr(sizeof...(Is) > 0) + { + static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); + } +} + +// return the index of first occurance in the sequence. 
+// return seq.size(), if not found +template +constexpr index_t container_find(sequence seq, index_t value) +{ + for(auto i = 0; i < seq.size(); i++) + { + if(seq[i] == value) + return i; + } + + return seq.size(); +} + +template +CK_TILE_HOST_DEVICE constexpr auto sequence_to_tuple_of_number(sequence) +{ + using Seq = sequence; + + return generate_tuple( + [&](auto i) { + constexpr index_t tmp = Seq::at(i); + return number{}; + }, + number{}); +} + +#if 0 +#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \ + [a_of_b_impl, a_size, bs_sizes] { \ + return ck_tile::generate_tuple( \ + [=](auto i) { \ + constexpr auto b_impl = a_of_b_impl[i]; \ + constexpr index_t b_size = bs_sizes[i]; \ + constexpr auto b = TO_SEQUENCE(b_impl, b_size); \ + return b; \ + }, \ + ck_tile::number{}); \ + }() +#else +// constexpr index_t can't be captured "-Wunused-lambda-capture" +// TODO: this is ugly +#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes) \ + [a_of_b_impl, bs_sizes] { \ + return ck_tile::generate_tuple( \ + [=](auto i) { \ + constexpr auto b_impl = a_of_b_impl[i]; \ + constexpr index_t b_size = bs_sizes[i]; \ + constexpr auto b = TO_SEQUENCE(b_impl, b_size); \ + return b; \ + }, \ + ck_tile::number{}); \ + }() +#endif + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/map.hpp b/include/ck_tile/core/container/map.hpp new file mode 100644 index 000000000..87b180caf --- /dev/null +++ b/include/ck_tile/core/container/map.hpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" + +namespace ck_tile { + +// naive map +template +struct map +{ + using pair_type = tuple; + using impl_type = array; + + impl_type impl_; + index_t size_; + + struct iterator + { + impl_type& impl_; + index_t pos_; + + CK_TILE_HOST_DEVICE constexpr iterator(impl_type& impl, index_t pos) + : impl_{impl}, pos_{pos} + { + } + + CK_TILE_HOST_DEVICE constexpr iterator& operator++() + { + pos_++; + return *this; + } + + CK_TILE_HOST_DEVICE constexpr bool operator!=(const iterator& other) const + { + return other.pos_ != pos_; + } + + CK_TILE_HOST_DEVICE constexpr pair_type& operator*() { return impl_.at(pos_); } + }; + + struct const_iterator + { + const impl_type& impl_; + index_t pos_; + + CK_TILE_HOST_DEVICE constexpr const_iterator(const impl_type& impl, index_t pos) + : impl_{impl}, pos_{pos} + { + } + + CK_TILE_HOST_DEVICE constexpr const_iterator& operator++() + { + pos_++; + + return *this; + } + + CK_TILE_HOST_DEVICE constexpr bool operator!=(const const_iterator& other) const + { + return other.pos_ != pos_; + } + + CK_TILE_HOST_DEVICE constexpr const pair_type& operator*() const { return impl_.at(pos_); } + }; + + CK_TILE_HOST_DEVICE constexpr map() : impl_{}, size_{0} {} + + CK_TILE_HOST_DEVICE constexpr index_t size() const { return size_; } + + CK_TILE_HOST_DEVICE void clear() { size_ = 0; } + + CK_TILE_HOST_DEVICE constexpr index_t find_position(const key& k) const + { + for(index_t i = 0; i < size(); i++) + { + if(impl_[i].template at<0>() == k) + { + return i; + } + } + + return size_; + } + + CK_TILE_HOST_DEVICE constexpr const_iterator find(const key& k) const + { + return const_iterator{impl_, find_position(k)}; + } + + CK_TILE_HOST_DEVICE constexpr iterator find(const key& k) + { + return iterator{impl_, 
find_position(k)}; + } + + CK_TILE_HOST_DEVICE constexpr const data& operator[](const key& k) const + { + const auto it = find(k); + + // FIXME + // assert(it.pos_ < size()); + + return impl_[it.pos_].template at<1>(); + } + + CK_TILE_HOST_DEVICE constexpr data& operator()(const key& k) + { + auto it = find(k); + + // if entry not found + if(it.pos_ == size()) + { + impl_(it.pos_).template at<0>() = k; + size_++; + } + + // FIXME + // assert(size_ <= max_size); + + return impl_(it.pos_).template at<1>(); + } + + // WARNING: needed by compiler for C++ range-based for loop only, don't use this function! + CK_TILE_HOST_DEVICE constexpr const_iterator begin() const { return const_iterator{impl_, 0}; } + + // WARNING: needed by compiler for C++ range-based for loop only, don't use this function! + CK_TILE_HOST_DEVICE constexpr const_iterator end() const + { + return const_iterator{impl_, size_}; + } + + // WARNING: needed by compiler for C++ range-based for loop only, don't use this function! + CK_TILE_HOST_DEVICE constexpr iterator begin() { return iterator{impl_, 0}; } + + // WARNING: needed by compiler for C++ range-based for loop only, don't use this function! + CK_TILE_HOST_DEVICE constexpr iterator end() { return iterator{impl_, size_}; } + + CK_TILE_HOST_DEVICE void print() const + { + printf("map{size_: %d, ", size_); + // + printf("impl_: ["); + // + for(const auto& [k, d] : *this) + { + printf("{key: "); + print(k); + printf(", data: "); + print(d); + printf("}, "); + } + // + printf("]"); + // + printf("}"); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/meta_data_buffer.hpp b/include/ck_tile/core/container/meta_data_buffer.hpp new file mode 100644 index 000000000..7493b93d8 --- /dev/null +++ b/include/ck_tile/core/container/meta_data_buffer.hpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include + +namespace ck_tile { + +// TODO: this structure is not intented to be used by user +template +struct meta_data_buffer +{ + CK_TILE_HOST_DEVICE constexpr meta_data_buffer() : buffer_{}, size_{0} {} + + template + CK_TILE_HOST_DEVICE constexpr meta_data_buffer(const X& x, const Xs&... xs) + : buffer_{}, size_{0} + { + push(x, xs...); + } + + template + CK_TILE_HOST_DEVICE constexpr void push(const T& data) + { + if constexpr(!std::is_empty_v) + { + constexpr index_t size = sizeof(T); + + auto tmp = bit_cast>(data); + + for(int i = 0; i < size; i++) + { + buffer_(size_) = tmp[i]; + + size_++; + } + } + } + + template + CK_TILE_HOST_DEVICE constexpr void push(const X& x, const Xs&... 
xs) + { + push(x); + push(xs...); + } + + template + CK_TILE_HOST_DEVICE constexpr T pop(index_t& pos) const + { + T data; + + if constexpr(!std::is_empty_v) + { + constexpr index_t size = sizeof(T); + + array tmp; + + for(int i = 0; i < size; i++) + { + tmp(i) = buffer_[pos]; + + pos++; + } + + data = bit_cast(tmp); + } + + return data; + } + + template + CK_TILE_HOST_DEVICE constexpr T get(index_t pos) const + { + constexpr index_t size = sizeof(T); + + array tmp; + + for(int i = 0; i < size; i++) + { + tmp(i) = buffer_[pos]; + + pos++; + } + + auto data = bit_cast(tmp); + + return data; + } + + // + array buffer_; + index_t size_ = 0; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/multi_index.hpp b/include/ck_tile/core/container/multi_index.hpp new file mode 100644 index 000000000..921c590df --- /dev/null +++ b/include/ck_tile/core/container/multi_index.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/utility/functional.hpp" + +namespace ck_tile { + +// Don't use tihs directly. This is for old CK's internal usage, +// in the future always use array instead +template +using multi_index = array; + +template +CK_TILE_HOST_DEVICE constexpr auto make_multi_index(Xs&&... xs) +{ + return make_array(index_t{xs}...); +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_zero_multi_index() +{ + return unpack([](auto... xs) { return make_multi_index(xs...); }, + typename uniform_sequence_gen::type{}); +} + +template +CK_TILE_HOST_DEVICE constexpr auto to_multi_index(const T& x) +{ + return unpack([](auto... ys) { return make_multi_index(ys...); }, x); +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator+=(multi_index& y, const X& x) +{ + static_assert(X::size() == NSize, "wrong! size not the same"); + static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; }); + return y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator-=(multi_index& y, const X& x) +{ + static_assert(X::size() == NSize, "wrong! size not the same"); + static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; }); + return y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator+(const multi_index& a, const T& b) +{ + using type = multi_index; + static_assert(T::size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] + b[i]; }); + return r; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator-(const multi_index& a, const T& b) +{ + using type = multi_index; + static_assert(T::size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] - b[i]; }); + return r; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index& a, const T& b) +{ + using type = multi_index; + static_assert(T::size() == NSize, "wrong! 
size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = a[i] * b[i]; }); + return r; +} + +// multi_index = index_t * multi_index +template +CK_TILE_HOST_DEVICE constexpr auto operator*(index_t a, const multi_index& x) +{ + multi_index r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; }); + return r; +} + +// multi_index = multi_index * index_t +template +CK_TILE_HOST_DEVICE constexpr auto operator*(const multi_index& x, index_t a) +{ + return a * x; +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp new file mode 100644 index 000000000..acf187cfc --- /dev/null +++ b/include/ck_tile/core/container/sequence.hpp @@ -0,0 +1,1114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/to_sequence.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/utility/functional.hpp" + +namespace ck_tile { + +template +struct static_for; + +template +struct sequence; + +template +struct sequence_split; + +template +struct sequence_reverse; + +template +struct sequence_map_inverse; + +template +struct is_valid_sequence_map; + +template +CK_TILE_HOST_DEVICE constexpr auto sequence_pop_front(sequence); + +template +CK_TILE_HOST_DEVICE constexpr auto sequence_pop_back(Seq); + +namespace impl { +// static_assert(__has_builtin(__type_pack_element), "can't find __type_pack_element"); +template +using at_index_t = __type_pack_element; +} // namespace impl + +// we could implement as below, similiar to std. But let's reduce the symbol name... +// template< class T, T... Ints > +// class integer_sequence; + +template +struct sequence +{ + using type = sequence; + using value_type = index_t; + + CK_TILE_HOST_DEVICE static constexpr index_t size() { return sizeof...(Is); } + CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }; + + template + CK_TILE_HOST_DEVICE static constexpr auto get() + { + static_assert(I < size(), "wrong! I too large"); + return number...>{}>{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto get(number) + { + static_assert(I < size(), "wrong! I too large"); + return number()>{}; + } + + CK_TILE_HOST_DEVICE static constexpr index_t at(index_t I) + { + // the last dummy element is to prevent compiler complain about empty array, when mSize = 0 + const index_t mData[size() + 1] = {Is..., 0}; + return mData[I]; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto at() + { + static_assert(I < size(), "wrong! I too large"); + return number...>{}>{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto at(number) + { + static_assert(I < size(), "wrong! I too large"); + return number()>{}; + } + + template + CK_TILE_HOST_DEVICE constexpr auto operator[](I i) const + { + return at(i); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto reorder_new_to_old(sequence /*new2old*/) + { + static_assert(sizeof...(Is) == sizeof...(IRs), + "wrong! reorder map should have the same size as sequence to be rerodered"); + + static_assert(is_valid_sequence_map>::value, "wrong! 
invalid reorder map"); + + return sequence{})...>{}; + } + + // MapOld2New is sequence<...> + template + CK_TILE_HOST_DEVICE static constexpr auto reorder_old_to_new(MapOld2New) + { + static_assert(MapOld2New::size() == size(), + "wrong! reorder map should have the same size as sequence to be rerodered"); + + static_assert(is_valid_sequence_map::value, "wrong! invalid reorder map"); + + return reorder_new_to_old(typename sequence_map_inverse::type{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto reverse() + { + return typename sequence_reverse::type{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto front() + { + static_assert(size() > 0, "wrong!"); + return get(number<0>{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto back() + { + static_assert(size() > 0, "wrong!"); + return get(number{}); + } + + CK_TILE_HOST_DEVICE static constexpr auto pop_front() { return sequence_pop_front(type{}); } + + CK_TILE_HOST_DEVICE static constexpr auto pop_back() { return sequence_pop_back(type{}); } + + template + CK_TILE_HOST_DEVICE static constexpr auto push_front(sequence) + { + return sequence{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto push_front(number...) + { + return sequence{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto push_back(sequence) + { + return sequence{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto push_back(number...) + { + return sequence{}; + } + + // pickup element at index + template + CK_TILE_HOST_DEVICE static constexpr auto extract(number...) + { + return sequence{})...>{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto extract(sequence) + { + return sequence{})...>{}; + } + + // modify element at index "I" with value "X" + template + CK_TILE_HOST_DEVICE static constexpr auto modify(number, number) + { + static_assert(I < size(), "wrong!"); + + using seq_split = sequence_split; + constexpr auto seq_left = typename seq_split::left_type{}; + constexpr auto seq_right = typename seq_split::right_type{}.pop_front(); + + return seq_left.push_back(number{}).push_back(seq_right); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto transform(F f) + { + return sequence{}; + } + + CK_TILE_HOST_DEVICE static void print() + { + printf("sequence{size: %d, data: [", size()); + ((printf("%d ", Is)), ...); + printf("]}"); + } +}; + +namespace impl { +template +struct __integer_sequence; + +template +struct __integer_sequence +{ + using seq_type = sequence; +}; +} // namespace impl + +// similiar +template +using make_index_sequence = + typename __make_integer_seq::seq_type; + +// merge sequence +template +struct sequence_merge +{ + using type = typename sequence_merge::type>::type; +}; + +template +struct sequence_merge, sequence> +{ + using type = sequence; +}; + +template +struct sequence_merge +{ + using type = Seq; +}; + +// generate sequence +template +struct sequence_gen +{ + template + struct sequence_gen_impl + { + static constexpr index_t NRemainLeft = NRemain / 2; + static constexpr index_t NRemainRight = NRemain - NRemainLeft; + static constexpr index_t IMiddle = IBegin + NRemainLeft; + + using type = typename sequence_merge< + typename sequence_gen_impl::type, + typename sequence_gen_impl::type>::type; + }; + + template + struct sequence_gen_impl + { + static constexpr index_t Is = G{}(number{}); + using type = sequence; + }; + + template + struct sequence_gen_impl + { + using type = sequence<>; + }; + + using type = typename sequence_gen_impl<0, NSize, F>::type; +}; + +// arithmetic sequence 
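+// Usage sketch (illustrative only; relies solely on the generators defined in this header):
+//   using evens = typename arithmetic_sequence_gen<0, 8, 2>::type;   // sequence<0, 2, 4, 6>
+//   static_assert(evens::size() == 4 && evens::at(1) == 2, "");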
+template +struct arithmetic_sequence_gen +{ + struct F + { + CK_TILE_HOST_DEVICE constexpr index_t operator()(index_t i) const + { + return i * Increment + IBegin; + } + }; + + using type0 = typename sequence_gen<(IEnd - IBegin) / Increment, F>::type; + using type1 = sequence<>; + + static constexpr bool kHasContent = + (Increment > 0 && IBegin < IEnd) || (Increment < 0 && IBegin > IEnd); + + using type = typename std::conditional::type; +}; + +template +struct arithmetic_sequence_gen<0, IEnd, 1> +{ + using type = make_index_sequence; +}; + +// uniform sequence +template +struct uniform_sequence_gen +{ + struct F + { + CK_TILE_HOST_DEVICE constexpr index_t operator()(index_t) const { return I; } + }; + + using type = typename sequence_gen::type; +}; + +// reverse inclusive scan (with init) sequence +template +struct sequence_reverse_inclusive_scan; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using old_scan = typename sequence_reverse_inclusive_scan, Reduce, Init>::type; + + static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.front()); + + using type = typename sequence_merge, old_scan>::type; +}; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using type = sequence; +}; + +template +struct sequence_reverse_inclusive_scan, Reduce, Init> +{ + using type = sequence<>; +}; + +// split sequence +template +struct sequence_split +{ + static constexpr index_t NSize = Seq{}.size(); + + using range0 = typename arithmetic_sequence_gen<0, I, 1>::type; + using range1 = typename arithmetic_sequence_gen::type; + + using left_type = decltype(Seq::extract(range0{})); + using right_type = decltype(Seq::extract(range1{})); +}; + +#if 0 +// reverse sequence +template +struct sequence_reverse +{ + static constexpr index_t NSize = Seq{}.size(); + + using seq_split = sequence_split; + using type = typename sequence_merge< + typename sequence_reverse::type, + typename sequence_reverse::type>::type; +}; + +template +struct sequence_reverse> +{ + using type = sequence; +}; + +template +struct sequence_reverse> +{ + using type = sequence; +}; +#endif + +namespace impl { +template +struct seq_reverse; + +template +struct seq_reverse, Ns...> +{ + template + using element = impl::at_index_t...>; + using type = sequence::value...>; +}; +} // namespace impl + +template +struct sequence_reverse> + : impl::seq_reverse, Ns...> +{ +}; + +// template +// using sequence_reverse_t = typename sequence_reverse::type; + +#if 1 +template +struct sequence_reduce +{ + using type = typename sequence_reduce::type>::type; +}; + +template +struct sequence_reduce, sequence> +{ + using type = sequence; +}; + +template +struct sequence_reduce +{ + using type = Seq; +}; +#endif + +template +struct sequence_sort_impl +{ + template + struct sorted_sequence_merge_impl + { + static constexpr bool choose_left = LeftValues::front() < RightValues::front(); + + static constexpr index_t chosen_value = + choose_left ? LeftValues::front() : RightValues::front(); + static constexpr index_t chosen_id = choose_left ? 
LeftIds::front() : RightIds::front(); + + using new_merged_values = decltype(MergedValues::push_back(number{})); + using new_merged_ids = decltype(MergedIds::push_back(number{})); + + using new_left_values = typename std:: + conditional::type; + using new_left_ids = + typename std::conditional::type; + + using new_right_values = typename std:: + conditional::type; + using new_right_ids = + typename std::conditional::type; + + using merge = sorted_sequence_merge_impl; + // this is output + using merged_values = typename merge::merged_values; + using merged_ids = typename merge::merged_ids; + }; + + template + struct sorted_sequence_merge_impl, + sequence<>, + MergedValues, + MergedIds, + Comp> + { + using merged_values = typename sequence_merge::type; + using merged_ids = typename sequence_merge::type; + }; + + template + struct sorted_sequence_merge_impl, + sequence<>, + RightValues, + RightIds, + MergedValues, + MergedIds, + Comp> + { + using merged_values = typename sequence_merge::type; + using merged_ids = typename sequence_merge::type; + }; + + template + struct sorted_sequence_merge + { + using merge = sorted_sequence_merge_impl, + sequence<>, + Comp>; + + using merged_values = typename merge::merged_values; + using merged_ids = typename merge::merged_ids; + }; + + static constexpr index_t nsize = Values::size(); + + using split_unsorted_values = sequence_split; + using split_unsorted_ids = sequence_split; + + using left_unsorted_values = typename split_unsorted_values::left_type; + using left_unsorted_ids = typename split_unsorted_ids::left_type; + using left_sort = sequence_sort_impl; + using left_sorted_values = typename left_sort::sorted_values; + using left_sorted_ids = typename left_sort::sorted_ids; + + using right_unsorted_values = typename split_unsorted_values::right_type; + using right_unsorted_ids = typename split_unsorted_ids::right_type; + using right_sort = sequence_sort_impl; + using right_sorted_values = typename right_sort::sorted_values; + using right_sorted_ids = typename right_sort::sorted_ids; + + using merged_sorted = sorted_sequence_merge; + + using sorted_values = typename merged_sorted::merged_values; + using sorted_ids = typename merged_sorted::merged_ids; +}; + +template +struct sequence_sort_impl, sequence, Compare> +{ + static constexpr bool choose_x = Compare{}(ValueX, ValueY); + + using sorted_values = typename std:: + conditional, sequence>::type; + using sorted_ids = + typename std::conditional, sequence>::type; +}; + +template +struct sequence_sort_impl, sequence, Compare> +{ + using sorted_values = sequence; + using sorted_ids = sequence; +}; + +template +struct sequence_sort_impl, sequence<>, Compare> +{ + using sorted_values = sequence<>; + using sorted_ids = sequence<>; +}; + +template +struct sequence_sort +{ + using unsorted_ids = typename arithmetic_sequence_gen<0, Values::size(), 1>::type; + using sort = sequence_sort_impl; + + // this is output + using type = typename sort::sorted_values; + using sorted2unsorted_map = typename sort::sorted_ids; +}; + +template +struct sequence_unique_sort +{ + template + struct sorted_sequence_uniquify_impl + { + static constexpr index_t current_value = RemainValues::front(); + static constexpr index_t current_id = RemainIds::front(); + + static constexpr bool is_unique_value = (current_value != UniquifiedValues::back()); + + using new_remain_values = decltype(RemainValues::pop_front()); + using new_remain_ids = decltype(RemainIds::pop_front()); + + using new_uniquified_values = + typename 
std::conditional{})), + UniquifiedValues>::type; + + using new_uniquified_ids = + typename std::conditional{})), + UniquifiedIds>::type; + + using uniquify = sorted_sequence_uniquify_impl; + + // this is output + using uniquified_values = typename uniquify::uniquified_values; + using uniquified_ids = typename uniquify::uniquified_ids; + }; + + template + struct sorted_sequence_uniquify_impl, + sequence<>, + UniquifiedValues, + UniquifiedIds, + Eq> + { + using uniquified_values = UniquifiedValues; + using uniquified_ids = UniquifiedIds; + }; + + template + struct sorted_sequence_uniquify + { + using uniquify = sorted_sequence_uniquify_impl, + sequence, + Eq>; + + using uniquified_values = typename uniquify::uniquified_values; + using uniquified_ids = typename uniquify::uniquified_ids; + }; + + using sort = sequence_sort; + using sorted_values = typename sort::type; + using sorted_ids = typename sort::sorted2unsorted_map; + + using uniquify = sorted_sequence_uniquify; + + // this is output + using type = typename uniquify::uniquified_values; + using sorted2unsorted_map = typename uniquify::uniquified_ids; +}; + +template +struct is_valid_sequence_map + : std::is_same::type, + typename sequence_sort>::type> +{ +}; + +template +struct sequence_map_inverse +{ + template + struct sequence_map_inverse_impl + { + static constexpr auto new_y2x = + WorkingY2X::modify(X2Y::get(number{}), number{}); + + using type = + typename sequence_map_inverse_impl:: + type; + }; + + template + struct sequence_map_inverse_impl + { + using type = WorkingY2X; + }; + + using type = + typename sequence_map_inverse_impl::type, + 0, + SeqMap::size()>::type; +}; + +template +CK_TILE_HOST_DEVICE constexpr bool operator==(sequence, sequence) +{ + return ((Xs == Ys) && ...); +} + +template +CK_TILE_HOST_DEVICE constexpr bool operator!=(sequence x, sequence y) +{ + return !(x == y); +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator+(sequence, sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return sequence<(Xs + Ys)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator-(sequence, sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return sequence<(Xs - Ys)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator*(sequence, sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return sequence<(Xs * Ys)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator/(sequence, sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! inconsistent size"); + + return sequence<(Xs / Ys)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator%(sequence, sequence) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong! 
inconsistent size"); + + return sequence<(Xs % Ys)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator+(sequence, number) +{ + return sequence<(Xs + Y)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator-(sequence, number) +{ + return sequence<(Xs - Y)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator*(sequence, number) +{ + return sequence<(Xs * Y)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator/(sequence, number) +{ + return sequence<(Xs / Y)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator%(sequence, number) +{ + return sequence<(Xs % Y)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator+(number, sequence) +{ + return sequence<(Y + Xs)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator-(number, sequence) +{ + return sequence<(Y - Xs)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator*(number, sequence) +{ + return sequence<(Y * Xs)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator/(number, sequence) +{ + return sequence<(Y / Xs)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator%(number, sequence) +{ + return sequence<(Y % Xs)...>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto sequence_pop_front(sequence) +{ + return sequence{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto sequence_pop_back(Seq) +{ + static_assert(Seq::size() > 0, "wrong! cannot pop an empty sequence!"); + return sequence_pop_front(Seq::reverse()).reverse(); +} + +template +CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...) +{ + return typename sequence_merge::type{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto transform_sequences(F f, sequence) +{ + return sequence{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto transform_sequences(F f, sequence, sequence) +{ + static_assert(sequence::size() == sequence::size(), "Dim not the same"); + + return sequence{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +transform_sequences(F f, sequence, sequence, sequence) +{ + static_assert(sequence::size() == sequence::size() && + sequence::size() == sequence::size(), + "Dim not the same"); + + return sequence{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, number) +{ + return typename sequence_reverse_inclusive_scan::type{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, number) +{ + return reverse_inclusive_scan_sequence(Seq::pop_front(), Reduce{}, number{}) + .push_back(number{}); +} + +template +CK_TILE_HOST_DEVICE constexpr auto inclusive_scan_sequence(Seq, Reduce, number) +{ + return reverse_inclusive_scan_sequence(Seq{}.reverse(), Reduce{}, number{}).reverse(); +} + +// e.g. 
Seq<2, 3, 4> --> Seq<0, 2, 5>, Init=0, Reduce=Add +// ResultSeq TargetSeq Reduce +template +struct sequence_exclusive_scan; + +template +struct sequence_exclusive_scan, sequence, Reduce> +{ + using old_scan = typename sequence_merge, + sequence{}.back())>>::type; + using type = typename sequence_exclusive_scan, Reduce>::type; +}; + +template +struct sequence_exclusive_scan, sequence, Reduce> +{ + using type = sequence; +}; + +template +struct sequence_exclusive_scan, sequence<>, Reduce> +{ + using type = sequence; +}; + +template +constexpr auto exclusive_scan_sequence(Seq, Reduce, number) +{ + // TODO: c++20 and later can pass in Reduce with a lambda expression + return typename sequence_exclusive_scan, Seq, Reduce>::type{}; +} + +template +constexpr auto prefix_sum_sequence(Seq) +{ + return typename sequence_exclusive_scan, + typename sequence_merge>::type, + plus>::type{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto pick_sequence_elements_by_ids(Seq, sequence /* ids */) +{ + return sequence{})...>{}; +} + +#if 1 +namespace detail { +template +struct pick_sequence_elements_by_mask_impl +{ + using new_work_seq = typename std::conditional::type; + + using type = + typename pick_sequence_elements_by_mask_impl::type; +}; + +template +struct pick_sequence_elements_by_mask_impl, sequence<>> +{ + using type = WorkSeq; +}; + +} // namespace detail + +template +CK_TILE_HOST_DEVICE constexpr auto pick_sequence_elements_by_mask(Seq, Mask) +{ + static_assert(Seq::size() == Mask::size(), "wrong!"); + + return typename detail::pick_sequence_elements_by_mask_impl, Seq, Mask>::type{}; +} + +namespace detail { +template +struct modify_sequence_elements_by_ids_impl +{ + using new_work_seq = decltype(WorkSeq::modify(RemainIds::front(), RemainValues::front())); + + using type = + typename modify_sequence_elements_by_ids_impl::type; +}; + +template +struct modify_sequence_elements_by_ids_impl, sequence<>> +{ + using type = WorkSeq; +}; +} // namespace detail + +template +CK_TILE_HOST_DEVICE constexpr auto modify_sequence_elements_by_ids(Seq, Values, Ids) +{ + static_assert(Values::size() == Ids::size() && Seq::size() >= Values::size(), "wrong!"); + + return typename detail::modify_sequence_elements_by_ids_impl::type{}; +} +#endif + +template +CK_TILE_HOST_DEVICE constexpr index_t +reduce_on_sequence(Seq, Reduce f, number /*initial_value*/) +{ + index_t result = Init; + + for(index_t i = 0; i < Seq::size(); ++i) + { + result = f(result, Seq::at(i)); + } + + return result; +} + +// TODO: a generic any_of for any container +template +CK_TILE_HOST_DEVICE constexpr bool sequence_any_of(Seq, F f) +{ + bool flag = false; + + for(index_t i = 0; i < Seq::size(); ++i) + { + flag = flag || f(Seq::at(i)); + } + + return flag; +} + +// TODO: a generic all_of for any container +template +CK_TILE_HOST_DEVICE constexpr bool sequence_all_of(Seq, F f) +{ + bool flag = true; + + for(index_t i = 0; i < Seq::size(); ++i) + { + flag = flag && f(Seq::at(i)); + } + + return flag; +} + +template +using sequence_merge_t = typename sequence_merge::type; + +template +using uniform_sequence_gen_t = typename uniform_sequence_gen::type; + +template +CK_TILE_HOST_DEVICE constexpr auto make_sequence(number...) 
+{ + return sequence{}; +} + +// F() returns index_t +// F use default constructor, so F cannot be lambda function +template +CK_TILE_HOST_DEVICE constexpr auto generate_sequence(F, number) +{ + return typename sequence_gen::type{}; +} + +// F() returns number<> +// F could be lambda function +template +CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F&& f, number) +{ + return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +template +struct tuple; + +template +CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple...>) +{ + return sequence{}; +} + +namespace detail { +template +struct sorted_sequence_histogram; + +template +struct sorted_sequence_histogram, sequence> +{ + template + constexpr auto operator()(Histogram& h) + { + if constexpr(x < r) + { + h.template at() += 1; + sorted_sequence_histogram, sequence>{}(h); + } + else + { + h.template at() = 1; + sorted_sequence_histogram, sequence>{}(h); + } + } +}; + +template +struct sorted_sequence_histogram, sequence> +{ + template + constexpr auto operator()(Histogram& h) + { + if constexpr(x < r) + { + h.template at() += 1; + } + } +}; +} // namespace detail + +template +struct array; // declare for later use (array->seq utility) + +// SeqSortedSamples: <0, 2, 3, 5, 7>, SeqRange: <0, 3, 6, 9> -> SeqHistogram : <2, 2, 1> +template +CK_TILE_HOST_DEVICE constexpr auto histogram_sorted_sequence(SeqSortedSamples, sequence) +{ + constexpr auto bins = sizeof...(rs); // or categories + constexpr auto histogram = [&]() { + array h{0}; // make sure this can clear all element to zero + detail::sorted_sequence_histogram<0, SeqSortedSamples, sequence>{}(h); + return h; + }(); + + return TO_SEQUENCE(histogram, bins); +} + +template +CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number) +{ + using T = remove_cvref_t{}))>; + + return unpack([&f](auto&&... is) { return array{f(is)...}; }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/span.hpp b/include/ck_tile/core/container/span.hpp new file mode 100644 index 000000000..eeb1f226a --- /dev/null +++ b/include/ck_tile/core/container/span.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include +#include +#include + +namespace ck_tile { + +// implement the c++20 std::span, lightweight, non-owning reference to a sequence +// weather it is dynamic or static range. Or can be seen as a view of a contiguous sequence +// TODO: do we need in device consider this is pointer? 
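+// Usage sketch (illustrative only; uses just the interface defined below):
+//   float buf[4] = {0.f, 1.f, 2.f, 3.f};
+//   ck_tile::span<const float> s(buf);    // non-owning view over a built-in array
+//   auto x = s[2];                        // 2.f; s.size() == 4, s.front() == 0.f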
+template +class span +{ + public: + using element_type = T; + using value_type = std::remove_cv_t; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using pointer = element_type*; + using const_pointer = const element_type*; + using reference = element_type&; + using const_reference = const element_type&; + using iterator = pointer; + using const_iterator = pointer; + + CK_TILE_HOST_DEVICE constexpr span() : span(nullptr, size_type{0}) {} + + CK_TILE_HOST_DEVICE constexpr span(pointer first, size_type count) : ptr_(first), size_(count) + { + } + + CK_TILE_HOST_DEVICE constexpr span(pointer first, pointer last) : span(first, last - first) {} + + template + CK_TILE_HOST_DEVICE constexpr span(element_type (&arr)[N]) noexcept : span(arr, N) + { + } + + template + CK_TILE_HOST_DEVICE constexpr span(std::array& arr) noexcept + : span(arr.data(), N) + { + } + + template + CK_TILE_HOST_DEVICE constexpr span(const Container& container) + : span(container.data(), container.size()) + { + } + + CK_TILE_HOST_DEVICE constexpr iterator begin() const noexcept { return ptr_; } + CK_TILE_HOST_DEVICE constexpr const_iterator cbegin() const noexcept { return begin(); } + + CK_TILE_HOST_DEVICE constexpr iterator end() const noexcept { return begin() + size(); } + CK_TILE_HOST_DEVICE constexpr const_iterator cend() const noexcept { return end(); } + + CK_TILE_HOST_DEVICE constexpr reference front() const { return *begin(); } + CK_TILE_HOST_DEVICE constexpr reference back() const { return *(--end()); } + + CK_TILE_HOST_DEVICE constexpr reference operator[](size_type idx) const + { + return *(begin() + idx); + } + CK_TILE_HOST_DEVICE constexpr pointer data() const noexcept { return ptr_; } + + CK_TILE_HOST_DEVICE constexpr size_type size() const noexcept { return size_; } + + private: + pointer ptr_; + size_type size_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/statically_indexed_array.hpp b/include/ck_tile/core/container/statically_indexed_array.hpp new file mode 100644 index 000000000..d6da50b62 --- /dev/null +++ b/include/ck_tile/core/container/statically_indexed_array.hpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/numeric/integer.hpp" + +namespace ck_tile { + +#if CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT == CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE + +template +using statically_indexed_array = tuple_array; + +#else + +// consider mark this struct as deprecated +template +using statically_indexed_array = array; + +#endif + +// consider always use ck_tile::array for this purpose +#if 0 +template +CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array(const X& x, const Xs&... xs) +{ + return statically_indexed_array(x, static_cast(xs)...); +} + +// make empty statically_indexed_array +template +CK_TILE_HOST_DEVICE constexpr auto make_statically_indexed_array() +{ + return statically_indexed_array(); +} +#endif +} // namespace ck_tile diff --git a/include/ck_tile/core/container/thread_buffer.hpp b/include/ck_tile/core/container/thread_buffer.hpp new file mode 100644 index 000000000..a7dad5233 --- /dev/null +++ b/include/ck_tile/core/container/thread_buffer.hpp @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
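+// Usage sketch for the thread_buffer<T, N> container defined in this file (illustrative only):
+//   ck_tile::thread_buffer<float, 4> tb{};
+//   tb(1) = 2.f;                    // per-element access
+//   auto& v = tb.get_as<float>();   // reinterpret the storage, here trivially as thread_buffer<float, 4>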
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/tuple.hpp" + +namespace ck_tile { + +#if CK_TILE_THREAD_BUFFER_DEFAULT == CK_TILE_THREAD_BUFFER_USE_TUPLE +template +using thread_buffer = tuple_array; + +template +CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts) +{ + return make_tuple(ts...); +} +#else + +#if 0 +template +using thread_buffer = array; + +template +CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts) +{ + return make_array(ts...); +} + +#endif + +// clang-format off +template +struct thread_buffer { + using value_type = remove_cvref_t; + static constexpr index_t N = N_; + + value_type data[N]; + + // TODO: this ctor can't ignore + CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {} + CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{o} {} + + CK_TILE_HOST_DEVICE static constexpr auto size() { return N; } + CK_TILE_HOST_DEVICE auto & get() {return data; } + CK_TILE_HOST_DEVICE const auto & get() const {return data; } + CK_TILE_HOST_DEVICE auto & get(index_t i) {return data[i]; } + CK_TILE_HOST_DEVICE const auto & get(index_t i) const {return data[i]; } + CK_TILE_HOST_DEVICE constexpr const auto& operator[](index_t i) const { return get(i); } + CK_TILE_HOST_DEVICE constexpr auto& operator[](index_t i) { return get(i); } + CK_TILE_HOST_DEVICE constexpr auto& operator()(index_t i) { return get(i); } // TODO: compatible + CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); } + CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); } + template CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); } + template CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); } + template CK_TILE_HOST_DEVICE constexpr auto& at(number) { return get(I); } + template CK_TILE_HOST_DEVICE constexpr const auto& at(number) const { return get(I); } + + template ::value, bool>::type = false> + CK_TILE_HOST_DEVICE constexpr auto _get_as() const + { + using X = remove_cvref_t; + + constexpr index_t kSPerX = vector_traits::vector_size; + static_assert(N % kSPerX == 0); + + union { + thread_buffer data {}; + // tuple_array sub_data; + value_type sub_data[N]; + } vx; + static_for<0, N, 1>{}( + [&](auto j) { vx.sub_data[j] = data[j]; }); + return vx.data; + } + + template ::value, bool>::type = false> + CK_TILE_HOST_DEVICE const constexpr remove_reference_t _get_as(number is) const + { + using X = remove_cvref_t; + + constexpr index_t kSPerX = vector_traits::vector_size; + + union { + X_ data {}; + tuple_array sub_data; + } vx; + static_for<0, kSPerX, 1>{}( + [&](auto j) { vx.sub_data(j) = operator[]((is * number{}) + j); }); + return vx.data; + } + +#if 0 + template ::value, bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void _set_as(number is, X_ x) + { + using X = remove_cvref_t; + + constexpr index_t kSPerX = vector_traits::vector_size; + + union { + X_ data; + tuple_array sub_data; + } vx {x}; + + static_for<0, kSPerX, 1>{}( + [&](auto j) { operator()((is * number{}) + j) = vx.sub_data[j]; }); + } +#endif + + +#define TB_COMMON_AS() \ + static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \ + constexpr int vx = sizeof(value_type) * N / sizeof(Tx) + + template + CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS(); + return reinterpret_cast&>(data);} + template + CK_TILE_HOST_DEVICE constexpr auto get_as() const {TB_COMMON_AS(); + if constexpr(sizeof(value_type) <= 1 ) + return _get_as(); // TODO: 
current compiler for 8bit data need use union to get data back, should fix in the future + else + return reinterpret_cast&>(data);} + template + CK_TILE_HOST_DEVICE auto & get_as(number) {TB_COMMON_AS(); + return reinterpret_cast&>(data).get(number{});} + template + CK_TILE_HOST_DEVICE constexpr auto get_as(number) const {TB_COMMON_AS(); + if constexpr(sizeof(value_type) <= 1 ) + return _get_as(number{}); // TODO: current compiler for 8bit data need use union to get data back, should fix in the future + else + return reinterpret_cast&>(data).get(number{});} + + template CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) + { TB_COMMON_AS(); reinterpret_cast&>(data).at(i) = x; } + template CK_TILE_HOST_DEVICE constexpr void set_as(number, const Tx & x) + { TB_COMMON_AS(); reinterpret_cast&>(data).at(number{}) = x; } + +#undef TB_COMMON_AS +}; +// clang-format on + +template +struct vector_traits; + +// specialization for array +template +struct vector_traits> +{ + using scalar_type = T; + static constexpr index_t vector_size = N; +}; + +#endif + +} // namespace ck_tile diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp new file mode 100644 index 000000000..cb8c2c70c --- /dev/null +++ b/include/ck_tile/core/container/tuple.hpp @@ -0,0 +1,781 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include +#include + +#ifndef CK_TILE_TUPLE_IMPL +#define CK_TILE_TUPLE_IMPL 1 +#endif + +namespace ck_tile { + +namespace impl { +template +struct tuple_array_impl; +} + +template +using tuple_array = typename impl::tuple_array_impl::type; + +namespace impl { + +// the place where content is stored +template > +struct tuple_object +{ +}; + +template +struct tuple_object +{ + CK_TILE_HOST_DEVICE constexpr tuple_object() {} +#if CK_TILE_TUPLE_IMPL == 0 + template + CK_TILE_HOST_DEVICE constexpr tuple_object(U&&) + { + } + template + CK_TILE_HOST_DEVICE constexpr tuple_object(const U&) + { + } + template + CK_TILE_HOST_DEVICE constexpr tuple_object(U&) + { + } +#elif CK_TILE_TUPLE_IMPL == 1 + template , tuple_object>::value, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr tuple_object(U&&) + { + } +#endif +}; + +template +struct tuple_object +{ + CK_TILE_HOST_DEVICE constexpr tuple_object() : element{} {} +#if CK_TILE_TUPLE_IMPL == 0 + template + CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward(e)) + { + } + template + CK_TILE_HOST_DEVICE constexpr tuple_object(const U& e) : element(e) + { + } + template + CK_TILE_HOST_DEVICE constexpr tuple_object(U& e) : element(e) + { + } +#elif CK_TILE_TUPLE_IMPL == 1 + template , tuple_object>::value, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr tuple_object(U&& e) : element(std::forward(e)) + { + } +#endif + T element; +}; + +// NOTE: we return a instance(not a reference) if content is empty +template +CK_TILE_HOST_DEVICE constexpr T getv(const tuple_object&) +{ + return {}; +} + +template +CK_TILE_HOST_DEVICE constexpr const T& getv(const tuple_object& x) +{ + return x.element; +} + +template +CK_TILE_HOST_DEVICE constexpr T& getv(tuple_object& x) +{ + return 
x.element; +} + +template +CK_TILE_HOST_DEVICE constexpr T&& getv(tuple_object&& x) +{ + return static_cast(x.element); +} + +template +struct tuple_base; + +template +struct tuple_base, T...> : tuple_object... +{ + CK_TILE_HOST_DEVICE constexpr tuple_base() = default; + +#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST +#define _ILE() (std::initializer_list{}.size() - 1) + template + CK_TILE_HOST_DEVICE constexpr tuple_base(std::initializer_list us) + : tuple_object(static_cast(*(us.begin() + (I >= _ILE() ? _ILE() : I))))... + { + } +#undef _ILE +#endif + +#if CK_TILE_TUPLE_IMPL == 0 + template + CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&&... u) + : tuple_object(std::forward(u))... + { + } + + template + CK_TILE_HOST_DEVICE constexpr explicit tuple_base(const U&... u) : tuple_object(u)... + { + } + + template + CK_TILE_HOST_DEVICE constexpr explicit tuple_base(U&... u) : tuple_object(u)... + { + } + + template + CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base, U...>&& u) + : tuple_object(getv(static_cast&&>(u)))... + { + } + + template + CK_TILE_HOST_DEVICE constexpr tuple_base(const tuple_base, U...>& u) + : tuple_object(getv(static_cast&>(u)))... + { + } + + template + CK_TILE_HOST_DEVICE constexpr tuple_base(tuple_base, U...>& u) + : tuple_object(getv(static_cast&>(u)))... + { + } +#elif CK_TILE_TUPLE_IMPL == 1 + template , tuple_base>::value, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr tuple_base(U&& u) : tuple_object(std::forward(u))... + { + } + + template = 2, bool>::type = false> + CK_TILE_HOST_DEVICE constexpr tuple_base(U&&... u) : tuple_object(std::forward(u))... + { + static_assert(sizeof...(I) == sizeof...(T) && sizeof...(I) == sizeof...(U), + "wrong! inconsistent size"); + } + +#endif +}; +} // namespace impl + +template +struct tuple : impl::tuple_base, T...> +{ + CK_TILE_HOST_DEVICE + static constexpr auto size() { return sizeof...(T); } + using base = impl::tuple_base, T...>; + CK_TILE_HOST_DEVICE constexpr tuple() = default; + +#if CK_TILE_TUPLE_CTOR_WITH_INITIALIZER_LIST + template + CK_TILE_HOST_DEVICE constexpr tuple(std::initializer_list us) : base(us) + { + } +#endif + +#if CK_TILE_TUPLE_IMPL == 0 + template + CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward(u)...) + { + } + + template + CK_TILE_HOST_DEVICE constexpr tuple(const U&... u) : base(u...) + { + } + + template + CK_TILE_HOST_DEVICE constexpr tuple(U&... u) : base(u...) + { + } + + template + CK_TILE_HOST_DEVICE constexpr tuple(tuple&& u) + : base(static_cast, U...>&&>(u)) + { + } + + template + CK_TILE_HOST_DEVICE constexpr tuple(const tuple& u) + : base(static_cast, U...>&>(u)) + { + } + + template + CK_TILE_HOST_DEVICE constexpr tuple(tuple& u) + : base(static_cast, U...>&>(u)) + { + } +#elif CK_TILE_TUPLE_IMPL == 1 + template < + typename U, + typename std::enable_if, tuple>::value, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr tuple(U&& u) : base(std::forward(u)) + { + } + + template = 2, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr tuple(U&&... u) : base(std::forward(u)...) + { + } +#endif + CK_TILE_HOST_DEVICE static constexpr bool is_static() + { + bool flag = true; + + static_for<0, sizeof...(T), 1>{}([&flag](auto i) { + flag &= is_static_v>>; + }); + + return flag; + } + +#define TP_COM_() static_assert(I < size(), "wrong! 
out of range") + // clang-format off + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) const { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) { TP_COM_(); return get(); } + + template CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number) const { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) at() { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number) { TP_COM_(); return get(); } + + template CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number) { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) operator[](number) const { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) operator()(number) { TP_COM_(); return get(); } // TODO: compatible + + // below function should be used under tuple_array<> type, no extra check will perform here + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() { return reinterpret_cast&>(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as() const { return reinterpret_cast&>(*this); } + // below index is for index *AFTER* type convert, not before + //template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) { TP_COM_(); return reinterpret_cast&>(*this).at(i); } + //template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(index_t i) const { TP_COM_(); return reinterpret_cast&>(*this).at(i); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number) { TP_COM_(); return reinterpret_cast&>(*this).at(number{}); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get_as(number) const { TP_COM_(); return reinterpret_cast&>(*this).at(number{}); } + + // template CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) { TP_COM_(); reinterpret_cast&>(*this).at(i) = x; } + template CK_TILE_HOST_DEVICE constexpr void set_as(number, const Tx & x) { TP_COM_(); reinterpret_cast&>(*this).at(number{}) = x; } + + // clang-format on +#undef TP_COM_ +}; + +template +struct vector_traits; + +// specialization for array +template +struct vector_traits> +{ + using scalar_type = __type_pack_element<0, T...>; + static constexpr index_t vector_size = sizeof...(T); +}; + +// template +// CK_TILE_HOST_DEVICE constexpr +// tuple +// make_tuple(T const&... t) +// { +// return {t...}; +// } +template +CK_TILE_HOST_DEVICE constexpr bool operator==(const tuple& a, const tuple& b) +{ + bool same = true; + + static_for<0, sizeof...(Xs), 1>{}([&](auto i) { + if(a[i] != b[i]) + { + same = false; + } + }); + + return same; +} + +template +CK_TILE_HOST_DEVICE constexpr bool operator!=(const tuple& a, const tuple& b) +{ + return !(a == b); +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs&&... xs) +{ + // here xs is always a lvalue as function arg + // Xs may deduced as (e.g try to pass in a integer in following cases) + // 1). if pass in a rvalue (like function return or int{}) -> Xs is "int" + // 2). if pass in a const lvalue -> Xs is "const int &" + // 3). 
if pass in a non-const lvalue -> Xs is "int &" + // so the return type of std::forward will dependes on Xs + // 1). std::forward -> int&& + // 2). std::forward -> const int& + // 3). std::forward -> int& + return tuple...>(std::forward(xs)...); +} + +// https://en.cppreference.com/w/cpp/utility/tuple/tie +template +constexpr tuple tie(Args&... args) noexcept +{ + return {args...}; +} + +template +struct tuple_concat; + +template +struct tuple_concat, tuple> +{ + using type = tuple; +}; + +namespace impl { +// be very careful using this type (because we want the internal type) +// template deduction will fail if infering the inner type +// e.g. +// template using some_wrapper = typename tuple_array_impl::type; +// template void foo(const some_wrapper&) {} +// -> compiler will fail to deduce this type, because this is under non-deduced context +// (https://en.cppreference.com/w/cpp/language/template_argument_deduction, "Non-deduced +// contexts") +// +// -> use this instead +// template void foo(const Tup&) {} +template +struct tuple_array_impl +{ + using type = typename tuple_concat::type, + typename tuple_array_impl::type>::type; +}; + +template +struct tuple_array_impl +{ + using type = tuple<>; +}; + +template +struct tuple_array_impl +{ + using type = tuple; +}; +} // namespace impl + +template +CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F&& f, number) +{ + return unpack([&f](auto&&... is) { return make_tuple(f(is)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +template +CK_TILE_HOST_DEVICE constexpr auto generate_tie(F&& f, number) +{ + return unpack([&f](auto&&... is) { return tie(f(is)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +// tx and ty are tuple of references, return type of will tuple of referennce (not rvalue) +template +CK_TILE_HOST_DEVICE constexpr auto concat_tuple_of_reference(const tuple& tx, + const tuple& ty) +{ + return unpack2( + [&](auto&&... zs) { return tuple{std::forward(zs)...}; }, + tx, + ty); +} + +template +CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple& tx, const tuple& ty) +{ + return unpack2( + [&](auto... zs) { return tuple{std::forward(zs)...}; }, + tx, + ty); +} + +// Support any number of tuples to concat (also 1) +template +CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple& tx) +{ + return tx; +} + +template +CK_TILE_HOST_DEVICE constexpr auto concat_tuple(const tuple& tx, const Tuples&... 
tuples) +{ + return concat_tuple(tx, concat_tuple(tuples...)); +} + +namespace detail { + +template +CK_TILE_HOST_DEVICE constexpr auto transform_tuples_impl(F f, const X& x, sequence) +{ + return make_tuple(f(x.at(number{}))...); +} + +template +CK_TILE_HOST_DEVICE constexpr auto +transform_tuples_impl(F f, const X& x, const Y& y, sequence) +{ + return make_tuple(f(x.at(number{}), y.at(number{}))...); +} + +template +CK_TILE_HOST_DEVICE constexpr auto +transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence) +{ + return make_tuple(f(x.at(number{}), y.at(number{}), z.at(number{}))...); +} + +} // namespace detail + +template +CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x) +{ + return detail::transform_tuples_impl( + f, x, typename arithmetic_sequence_gen<0, X::size(), 1>::type{}); +} + +template +CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y) +{ + return detail::transform_tuples_impl( + f, x, y, typename arithmetic_sequence_gen<0, X::size(), 1>::type{}); +} + +template +CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z) +{ + return detail::transform_tuples_impl( + f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{}); +} + +// By default unroll to the flatten +template +CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple<>& t) +{ + return t; +} + +template +CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const T& t) +{ + return make_tuple(t); +} + +template +CK_TILE_HOST_DEVICE constexpr auto unroll_nested_tuple(const tuple& t) +{ + if constexpr(Depth == MaxDepth) + { + return t; + } + else + { + return unpack( + [&](auto&&... ts) { + return concat_tuple(unroll_nested_tuple(ts)...); + }, + t); + } +} + +template +CK_TILE_HOST_DEVICE constexpr auto tuple_reverse(const tuple& t) +{ + return generate_tuple( + [&](auto i) { + using Idx = number::size() - i - 1>; + return t.at(Idx{}); + }, + number::size()()>{}); +} + +// Reduce tuple values in specific range using Function +template +CK_TILE_HOST_DEVICE constexpr auto tuple_reduce(F&& f, const tuple& t) +{ + static_assert(Idx < End, "Wrong parameters for tuple_reduce"); + if constexpr(Idx + 1 == End) + { + return t.at(number{}); + } + else + { + return f(t.at(number{}), tuple_reduce(f, t)); + } +} + +template +using is_tuple = decltype(std::declval().IsTuple()); + +template +CK_TILE_HOST_DEVICE constexpr auto is_nested_tuple(const tuple&) +{ + return (is_detected::value || ...); +} + +template +CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const T&) +{ + return depth; +} + +template +CK_TILE_HOST_DEVICE constexpr auto tuple_depth(const tuple&) +{ + return max(tuple_depth(Ts{})...); +} + +template +CK_TILE_HOST_DEVICE constexpr auto to_array_of_array(tuple t_of_s) +{ + constexpr index_t n0 = sizeof...(Seqs); + + constexpr index_t max_n1 = [&] { + index_t max_n1_ = 0; + + static_for<0, n0, 1>{}([&](auto i0) { + constexpr index_t n1 = t_of_s[i0].size(); + + max_n1_ = max_n1_ < n1 ? n1 : max_n1_; + }); + + return max_n1_; + }(); + + array, n0> a_of_a{{-1}}; + + static_for<0, n0, 1>{}([&](auto i0) { + constexpr index_t n1 = t_of_s[i0].size(); + + static_for<0, n1, 1>{}([&](auto i1) { a_of_a(i0)(i1) = t_of_s[i0][i1]; }); + }); + + return a_of_a; +} + +// Here should use MultiIndex, instead of tuple, although the former +// is the alias of the latter. This is because compiler cannot infer the NSize if +// using MultiIndex +// TODO: how to fix this? 
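
All of the helpers above (generate_tuple, transform_tuples, tuple_reduce, and the elementwise index arithmetic that follows) lean on the same mechanism: expand an index sequence, apply a callable per element, and repack or fold the results. The sketch below shows that pattern in isolation, using std::tuple and std::index_sequence in place of the CK containers; the function names are illustrative and are not part of the CK API.

// Standalone sketch, C++17; std::tuple stands in for ck_tile::tuple.
#include <cstddef>
#include <iostream>
#include <tuple>
#include <utility>

// Apply f to every element and collect the results in a new tuple
// (same shape as transform_tuples above, but over std::tuple).
template <typename F, typename Tuple, std::size_t... I>
constexpr auto transform_tuple_impl(F f, const Tuple& t, std::index_sequence<I...>)
{
    return std::make_tuple(f(std::get<I>(t))...);
}

template <typename F, typename... Ts>
constexpr auto transform_tuple(F f, const std::tuple<Ts...>& t)
{
    return transform_tuple_impl(f, t, std::index_sequence_for<Ts...>{});
}

// Sum all elements with a fold expression, analogous in spirit to tuple_reduce.
template <typename Tuple, std::size_t... I>
constexpr auto sum_tuple_impl(const Tuple& t, std::index_sequence<I...>)
{
    return (... + std::get<I>(t));
}

template <typename... Ts>
constexpr auto sum_tuple(const std::tuple<Ts...>& t)
{
    return sum_tuple_impl(t, std::index_sequence_for<Ts...>{});
}

int main()
{
    constexpr auto t       = std::make_tuple(1, 2, 3);
    constexpr auto doubled = transform_tuple([](int x) { return 2 * x; }, t);
    static_assert(std::get<2>(doubled) == 6, "elementwise transform");
    static_assert(sum_tuple(t) == 6, "reduction");
    std::cout << sum_tuple(doubled) << '\n'; // prints 12
}
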
+template ::value && !std::is_floating_point::value, bool> = + false> +CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple& y, const X& x) +{ + static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Ys); + static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; }); + return y; +} + +template ::value && !std::is_floating_point::value, bool> = + false> +CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple& y, const X& x) +{ + static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Ys); + static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; }); + return y; +} + +template ::value && !std::is_floating_point::value, bool> = + false> +CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] + y[i]; }); + return r; +} + +template ::value && !std::is_floating_point::value, bool> = + false> +CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] - y[i]; }); + return r; +} + +template ::value && !std::is_floating_point::value, bool> = + false> +CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = x[i] * y[i]; }); + return r; +} + +// MultiIndex = scalar * MultiIndex +template < + typename... Xs, + typename Y, + std::enable_if_t::value || std::is_floating_point::value, bool> = false> +CK_TILE_HOST_DEVICE constexpr auto operator*(Y a, const tuple& x) +{ + constexpr index_t NSize = sizeof...(Xs); + tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r[i] = a * x[i]; }); + return r; +} + +// MultiIndex = MultiIndex * scalar +template < + typename... Xs, + typename Y, + std::enable_if_t::value || std::is_floating_point::value, bool> = false> +CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple& x, Y a) +{ + return a * x; +} + +template +CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple& x, const tuple& y) +{ + static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!"); + constexpr index_t NSize = sizeof...(Xs); + return generate_tuple([&](auto i) { return x[i] / y[i]; }, number{}); +} + +} // namespace ck_tile + +#include +// WARNING: needed by compiler for C++ structured binding support only, don't use this +namespace std { + +template +struct tuple_size> : std::integral_constant +{ +}; + +template +struct tuple_element> : std::tuple_element> +{ +}; + +template +struct tuple_size> : std::integral_constant +{ +}; + +template +struct tuple_element> + : std::tuple_element> +{ +}; + +} // namespace std + +#if 1 +#define TO_TUPLE_OF_NUMBER(a, n) \ + _Pragma("clang diagnostic push") _Pragma( \ + "clang diagnostic ignored \"-Wc++20-extensions\"")[a]( \ + ck_tile::sequence) \ + { \ + return ck_tile::tuple{}]>...>{}; \ + } \ + (ck_tile::make_index_sequence{}) _Pragma("clang diagnostic pop") +#else +#define TO_TUPLE_OF_NUMBER(arr, n_) \ + [&arr, n_] { \ + static_assert(arr.size() >= n_, "wrong! 
out of bound"); \ + \ + static_assert(n_ < 7, "not implemented"); \ + \ + if constexpr(n_ == 0) \ + { \ + return ck_tile::tuple<>{}; \ + } \ + else if constexpr(n_ == 1) \ + { \ + return ck_tile::tuple>{}; \ + } \ + else if constexpr(n_ == 2) \ + { \ + return ck_tile::tuple, number>{}; \ + } \ + else if constexpr(n_ == 3) \ + { \ + return ck_tile::tuple, number, number>{}; \ + } \ + else if constexpr(n_ == 4) \ + { \ + return ck_tile:: \ + tuple, number, number, number>{}; \ + } \ + else if constexpr(n_ == 5) \ + { \ + return ck_tile::tuple, \ + number, \ + number, \ + number, \ + number>{}; \ + } \ + else if constexpr(n_ == 6) \ + { \ + return ck_tile::tuple, \ + number, \ + number, \ + number, \ + number, \ + number>{}; \ + } \ + }() +#endif diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp new file mode 100644 index 000000000..071387163 --- /dev/null +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -0,0 +1,342 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/numeric.hpp" +#include + +#pragma once + +namespace ck_tile { + +enum class bf16_rounding_mode +{ + standard = 0, // rtn + truncate_with_nan, + truncate, +}; + +template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> +CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant = {}); + +template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> +CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant = {}); + +CK_TILE_HOST_DEVICE +constexpr float bf16_to_float_raw(uint16_t x); + +CK_TILE_HOST_DEVICE +constexpr double bf16_to_double_raw(uint16_t x); + +#if CK_TILE_USE_CUSTOM_DATA_TYPE +// HIP use __hip_bfloat16 as struct +struct alignas(2) bfloat16_t +{ + using raw_type = uint16_t; + raw_type data; + + CK_TILE_HOST_DEVICE + static constexpr bfloat16_t bit_cast(raw_type x) + { + bfloat16_t y; + y.data = x; + return y; + } + + // constructor + constexpr bfloat16_t() : data() {} + + // construct from float + CK_TILE_HOST_DEVICE + explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {} + + // construct from double + CK_TILE_HOST_DEVICE + explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {} + + // construct from int + CK_TILE_HOST_DEVICE + explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast(x))) {} + + // construct from unsigned int + CK_TILE_HOST_DEVICE + explicit constexpr bfloat16_t(const unsigned int& x) + : data(float_to_bf16_raw(static_cast(x))) + { + } + + // cast to float + CK_TILE_HOST_DEVICE + explicit constexpr operator float() const { return bf16_to_float_raw(data); } + + // cast to float + CK_TILE_HOST_DEVICE + explicit constexpr operator double() const { return bf16_to_double_raw(data); } + + // cast to int + CK_TILE_HOST_DEVICE + explicit constexpr operator int() const { return static_cast(bf16_to_float_raw(data)); } + + // internal access + CK_TILE_HOST_DEVICE + constexpr raw_type& get() { return data; } + + CK_TILE_HOST_DEVICE + constexpr raw_type get() const { return data; } +}; +template +struct native_t; + +template <> +struct native_t +{ + using type = ushort; +}; +using bf16_t = bfloat16_t; +using bf16_raw_t = typename bf16_t::raw_type; +#else +using bfloat16_t = ushort; +using bf16_t = bfloat16_t; 
+using bf16_raw_t = uint16_t; +#endif +// round to nearest +CK_TILE_HOST_DEVICE +constexpr uint16_t float_to_bf16_rtn_raw(float f) +{ + union + { + float fp32; + uint32_t int32; + } u = {f}; + if(~u.int32 & 0x7f800000) + { + // When the exponent bits are not all 1s, then the value is zero, normal, + // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus + // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd). + // This causes the bfloat16's mantissa to be incremented by 1 if the 16 + // least significant bits of the float mantissa are greater than 0x8000, + // or if they are equal to 0x8000 and the least significant bit of the + // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when + // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already + // has the value 0x7f, then incrementing it causes it to become 0x00 and + // the exponent is incremented by one, which is the next higher FP value + // to the unrounded bfloat16 value. When the bfloat16 value is subnormal + // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up + // to a normal value with an exponent of 0x01 and a mantissa of 0x00. + // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F, + // incrementing it causes it to become an exponent of 0xFF and a mantissa + // of 0x00, which is Inf, the next higher value to the unrounded value. + u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even + } + else if(u.int32 & 0xffff) + { + // When all of the exponent bits are 1, the value is Inf or NaN. + // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero + // mantissa bit. Quiet NaN is indicated by the most significant mantissa + // bit being 1. Signaling NaN is indicated by the most significant + // mantissa bit being 0 but some other bit(s) being 1. If any of the + // lower 16 bits of the mantissa are 1, we set the least significant bit + // of the bfloat16 mantissa, in order to preserve signaling NaN in case + // the bloat16's mantissa bits are all 0. 
+ u.int32 |= 0x10000; // Preserve signaling NaN + } + return uint16_t(u.int32 >> 16); +} + +// Truncate instead of rounding, preserving SNaN +CK_TILE_HOST_DEVICE +constexpr uint16_t float_to_bf16_truc_nan_raw(float f) +{ + union + { + float fp32; + uint32_t int32; + } u = {f}; + return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff)); +} + +// Fast truncate instead of rounding, RTZ +CK_TILE_HOST_DEVICE +constexpr uint16_t float_to_bf16_truc_raw(float f) +{ + union + { + float fp32; + uint32_t int32; + } u = {f}; + return uint16_t(u.int32 >> 16); +} + +template +CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant) +{ + if constexpr(rounding == bf16_rounding_mode::standard) + return float_to_bf16_rtn_raw(f); + else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan) + return float_to_bf16_truc_nan_raw(f); + else + return float_to_bf16_truc_raw(f); +} + +template +CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant) +{ + return float_to_bf16_raw(static_cast(f), constant{}); +} + +CK_TILE_HOST_DEVICE +constexpr float bf16_to_float_raw(uint16_t x) +{ + union + { + uint32_t int32; + float fp32; + } u = {uint32_t(x) << 16}; + return u.fp32; +} + +CK_TILE_HOST_DEVICE +constexpr double bf16_to_double_raw(uint16_t x) +{ + return static_cast(bf16_to_float_raw(x)); +} + +template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> +CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant = {}) +{ + return bit_cast(float_to_bf16_raw(f, constant{})); +} + +template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> +CK_TILE_HOST_DEVICE constexpr bfloat16_t double_to_bf16(double f, constant = {}) +{ + return bit_cast(double_to_bf16_raw(f, constant{})); +} + +CK_TILE_HOST_DEVICE +constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast(x)); } + +CK_TILE_HOST_DEVICE +constexpr double bf16_to_double(bfloat16_t x) { return static_cast(bf16_to_float_raw(x)); } + +template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> +CK_TILE_HOST_DEVICE bfloat16_t constexpr fp16_to_bf16(half_t f, constant = {}) +{ + return bit_cast(float_to_bf16_raw(static_cast(f), constant{})); +} + +CK_TILE_HOST_DEVICE +constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast(static_cast(x)); } + +template +struct numeric; + +template <> +struct numeric +{ + // minimum finite value, or minimum positive normalized value for float + CK_TILE_HOST_DEVICE static constexpr bfloat16_t min() + { + return bit_cast(static_cast(0x0080)); + } + + // minumum finite value + CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest() + { + return bit_cast(static_cast(0xff7f)); + } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr bfloat16_t max() + { + return bit_cast(static_cast(0x7f7f)); + } + + // difference between 1.0 and next value representable by float + CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon() + { + return bit_cast(static_cast(0x1000)); + } + + // maximum rounding error + // maximum rounding error + // bin : f edcba 9876543210 + // bits: s eeeeeeee mmmmmmm + // 0 01111110 0000000 (0.5) + // + CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() + { + return bit_cast(static_cast(0x3f00)); + } + + // positive infinity value + CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity() + { + return bit_cast(static_cast(0x7f80)); + } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN() + { + return bit_cast(static_cast(0x7FFF)); + } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr 
bfloat16_t signaling_NaN() + { + return bit_cast(static_cast(0x7FFF)); + } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min() + { + return bit_cast(static_cast(0x0001)); + } + CK_TILE_HOST_DEVICE static constexpr bfloat16_t zero() + { + return bit_cast(static_cast(0)); + } +}; + +#if CK_TILE_USE_CUSTOM_DATA_TYPE +CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t) +#endif + +// math +CK_TILE_HOST_DEVICE +bfloat16_t abs(const bfloat16_t& x) +{ + return bit_cast(static_cast(bit_cast(x) & 0x7fff)); +} + +CK_TILE_HOST_DEVICE +bool isnan(const bfloat16_t& x) +{ + uint16_t xx = bit_cast(x); + return (xx & 0x7FFF) > 0x7C00; +} + +CK_TILE_DEVICE +bfloat16_t sqrt(bfloat16_t x) +{ + return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); +}; + +CK_TILE_DEVICE +bfloat16_t exp(bfloat16_t x) { return static_cast(__expf(static_cast(x))); }; + +CK_TILE_DEVICE +bfloat16_t exp2(bfloat16_t x) { return static_cast(exp2f(static_cast(x))); }; + +CK_TILE_DEVICE +bfloat16_t log(bfloat16_t x) { return static_cast(__logf(static_cast(x))); }; + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp new file mode 100644 index 000000000..bad1009f2 --- /dev/null +++ b/include/ck_tile/core/numeric/float8.hpp @@ -0,0 +1,871 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/utility/random.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/numeric.hpp" +#include +#include + +#pragma once + +namespace ck_tile { + +// fp8 rounding modes +// use standard for rounding to nearest, the faster one +// use stochastic for stochastic rounding, helps to avoid error accumulation +enum class fp8_rounding_mode +{ + standard = 0, + stochastic +}; + +/* + * ______________NANOO_________________ | ______________IEEE________________ + * e4m3 e5m2 | e4m3 e5m2 + * bias : 8 16 | 7 15 + * inf : 1.0000.000 1.00000.00 | N/A s.11111.00 + * Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11} + * zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00 + * Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344) + * Max(snorm): s.0000.111 s.00000.11 | s.0000.111(448) s.00000.11(57344) + * 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05 + * Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00 + * 2^-7(0.00078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05) + * Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01 + * 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05) + */ + +template (CK_TILE_FLOAT_TO_FP8_DEFAULT)> +CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant = {}); + +template (CK_TILE_FLOAT_TO_FP8_DEFAULT)> +CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant = {}); + +CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t); +CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t); + +#if CK_TILE_USE_CUSTOM_DATA_TYPE +struct alignas(1) float8_e4m3_t +{ + static constexpr int exponent = 4; + static constexpr int mantissa = 3; +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + static constexpr int bias = 1 << (exponent - 1); // NANOO +#else + static constexpr 
int bias = (1 << (exponent - 1)) - 1; // IEEE +#endif + using raw_type = uint8_t; + raw_type data; + + CK_TILE_HOST_DEVICE + static constexpr float8_e4m3_t bit_cast(raw_type x) + { + float8_e4m3_t y; + y.data = x; + return y; + } + + // constructor + constexpr float8_e4m3_t() : data() {} + + // construct from float + CK_TILE_HOST_DEVICE + explicit constexpr float8_e4m3_t(const float& x) : data(float_to_fp8_raw(x)) {} + + // construct from int + CK_TILE_HOST_DEVICE + explicit constexpr float8_e4m3_t(const int& x) : data(float_to_fp8_raw(static_cast(x))) + { + } + + // construct from unsigned int + CK_TILE_HOST_DEVICE + explicit constexpr float8_e4m3_t(const unsigned int& x) + : data(float_to_fp8_raw(static_cast(x))) + { + } + + // cast to float + CK_TILE_HOST_DEVICE + explicit constexpr operator float() const { return fp8_to_float_raw(data); } + + // cast to int + CK_TILE_HOST_DEVICE + explicit constexpr operator int() const { return static_cast(fp8_to_float_raw(data)); } + + // internal access + CK_TILE_HOST_DEVICE + constexpr raw_type& get() { return data; } + + CK_TILE_HOST_DEVICE + constexpr raw_type get() const { return data; } +}; +using fp8_t = float8_e4m3_t; +using fp8_raw_t = typename fp8_t::raw_type; + +struct alignas(1) float8_e5m2_t +{ + static constexpr int exponent = 5; + static constexpr int mantissa = 2; +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + static constexpr int bias = 1 << (exponent - 1); // NANOO +#else + static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE +#endif + using raw_type = uint8_t; + raw_type data; + + CK_TILE_HOST_DEVICE + static constexpr float8_e5m2_t bit_cast(raw_type x) + { + float8_e5m2_t y; + y.data = x; + return y; + } + + // constructor + constexpr float8_e5m2_t() : data() {} + + // construct from float + CK_TILE_HOST_DEVICE + explicit constexpr float8_e5m2_t(const float& x) : data(float_to_bf8_raw(x)) {} + + // construct from int + CK_TILE_HOST_DEVICE + explicit constexpr float8_e5m2_t(const int& x) : data(float_to_bf8_raw(static_cast(x))) + { + } + + // construct from unsigned int + CK_TILE_HOST_DEVICE + explicit constexpr float8_e5m2_t(const unsigned int& x) + : data(float_to_bf8_raw(static_cast(x))) + { + } + + // cast to float + CK_TILE_HOST_DEVICE + explicit constexpr operator float() const { return bf8_to_float_raw(data); } + + // cast to int + CK_TILE_HOST_DEVICE + explicit constexpr operator int() const { return static_cast(bf8_to_float_raw(data)); } + + // internal access + CK_TILE_HOST_DEVICE + constexpr raw_type& get() { return data; } + + CK_TILE_HOST_DEVICE + constexpr raw_type get() const { return data; } +}; +using bf8_t = float8_e5m2_t; +using bf8_raw_t = typename bf8_t::raw_type; + +template +struct native_t; + +template <> +struct native_t +{ + using type = _BitInt(8); +}; + +template <> +struct native_t +{ + using type = unsigned _BitInt(8); +}; + +#else +using fp8_t = _BitInt(8); +using fp8_raw_t = uint8_t; +using bf8_t = unsigned _BitInt(8); +using bf8_raw_t = uint8_t; +#endif + +// below is sw fp8 conversion, not utilizing hw instruction +namespace impl { + +template +CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng) +{ + // fp8/bf8 exponent/mantissa layout + constexpr int out_exp = numeric_traits::exp; + constexpr int out_mant = numeric_traits::mant; + + // original type exponent/mantissa layout + constexpr int in_exp = numeric_traits::exp; + constexpr int in_mant = numeric_traits::mant; + + int exponent, bias; + uint32_t head, mantissa, sign; + // nan code is same for float and 
half +#if CK_TILE_USE_CUSTOM_DATA_TYPE + constexpr Y nan_code = + numeric::quiet_NaN(); // __builtin_bit_cast(Y, static_cast(0x80)); +#else + constexpr Y nan_code = 0x80; +#endif + + constexpr uint32_t nan_mask = numeric_traits::nan_mask; + + // convert to bitwise + using T_bitwise = typename numeric_traits::bitwise_type; + T_bitwise x_bitwise = *(reinterpret_cast(&x)); + + // unpack the input, depends on datatype + head = x_bitwise & numeric_traits::head_mask; + mantissa = x_bitwise & numeric_traits::mant_mask; + exponent = (head >> in_mant) & numeric_traits::exp_mask; + sign = head >> (in_exp + in_mant); + bias = numeric_traits::bias; + + uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant); + uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1; + constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2); + + if constexpr(negative_zero_nan) + { + if((x_bitwise & nan_mask) == nan_mask) + return nan_code; + } + else + { + if((x_bitwise & nan_mask) == nan_mask) + return signed_inf + (mantissa != 0 ? 1 : 0); + } + + // check if x is 0.0 + if(x_bitwise == 0) + return __builtin_bit_cast(Y, static_cast(0)); + + // First need to check if it is normal or denorm as there is a difference of implict 1 + // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift + // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for + // RNE, no need to add rng. Then probably need to check whether there is carry and adjust + // exponent and mantissa again3 + + // For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits + const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // out_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, out_exponent, exponent_diff; + + if(exponent == 0) + { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 +here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has +exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in +fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. +In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = out_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } + else + { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if(act_exponent <= out_denormal_act_exponent) + { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. + For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implict 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. 
+ So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = out_denormal_act_exponent - act_exponent; + } + else + { // both fp32/fp16 and f8 are in normal range + exponent_diff = + 0; // exponent_diff=0 does not mean there is no difference for this case, + // act_exponent could be larger. Just that it does not need shift mantissa + } + mantissa += (1 << in_mant); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) == + (1 << (in_mant - out_mant + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we + shift right as shift right could rip off some residual part and make something not midpoint look + like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than + midpoint, but after shift right by 4 bits, it would look like midpoint. */ + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1 << in_mant); + // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent + out_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + bool odd = + mantissa & + (1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1 + mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; + + // Now we deal with overflow + if(out_exponent == 0) + { + if((1 << in_mant) & mantissa) + { + out_exponent = 1; // denormal overflow to become normal, promote exponent + // No need to make 1 implicit now as it will be addressed later + } + } + else + { + if((1 << (in_mant + 1)) & mantissa) + { + mantissa >>= 1; + out_exponent++; + // No need to make 1 implicit now as it will be addressed later + } + } + + mantissa >>= (in_mant - out_mant); + + if(out_exponent > max_exp) + { + if(clip) + { + mantissa = (1 << out_mant) - 1; + out_exponent = max_exp; + } + else + { + return __builtin_bit_cast(Y, static_cast(signed_inf)); + } + } + + // check if x is 0.0 or -0.0 + if(out_exponent == 0 && mantissa == 0) + return __builtin_bit_cast( + Y, static_cast(negative_zero_nan ? 
0 : (sign << (out_exp + out_mant)))); + mantissa &= (1 << out_mant) - 1; + return __builtin_bit_cast(Y, + static_cast((sign << (out_exp + out_mant)) | + (out_exponent << out_mant) | mantissa)); +} + +template +CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x) +{ + // fp8/bf8 exponent/mantissa layout + constexpr int in_exp = numeric_traits::exp; + constexpr int in_mant = numeric_traits::mant; + + // resulting type exponent/mantissa layout + constexpr int out_exp = numeric_traits::exp; + constexpr int out_mant = numeric_traits::mant; + uint8_t x_raw = __builtin_bit_cast(uint8_t, x); + + // prepare the codes + constexpr uint8_t nan_code = 0x80; + Y Inf, NegInf, NaN, Neg0; + using T_bitwise = typename numeric_traits::bitwise_type; + + constexpr T_bitwise Inf_bitwise = numeric_traits::Inf; + constexpr T_bitwise NegInf_bitwise = numeric_traits::NegInf; + constexpr T_bitwise NaN_bitwise = numeric_traits::NaN; + constexpr T_bitwise Neg0_bitwise = numeric_traits::Neg0; + + Inf = *(reinterpret_cast(&Inf_bitwise)); + NegInf = *(reinterpret_cast(&NegInf_bitwise)); + NaN = *(reinterpret_cast(&NaN_bitwise)); + Neg0 = *(reinterpret_cast(&Neg0_bitwise)); + + // check if x is 0.0 + if(x_raw == 0) + return static_cast(0); + + // unpack the input + uint32_t sign = x_raw >> (in_exp + in_mant); + uint32_t mantissa = x_raw & ((1 << in_mant) - 1); + int exponent = (x_raw & 0x7F) >> in_mant; + + constexpr int exp_low_cutoff = + (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + T_bitwise retval; + + if constexpr(negative_zero_nan) + { + if(x_raw == nan_code) + return NaN; + } + else + { + if(x_raw == nan_code) + return Neg0; + if(exponent == ((1 << in_exp) - 1)) + return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN; + } + + if((numeric_traits::mant == 10) && (numeric_traits::mant == 2) && !negative_zero_nan) + { + retval = x_raw; + retval <<= 8; + return *(reinterpret_cast(&retval)); + } + + // subnormal input + if(exponent == 0) + { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + clz(mantissa) - (32 - in_mant); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << in_mant) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= out_mant - in_mant; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if(exponent <= 0) + { + mantissa |= 1 << out_mant; + mantissa >>= 1 - exponent; + exponent = 0; + } + + retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; + return *(reinterpret_cast(&retval)); +} + +template +CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng) +{ + // check datatypes + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "Only half and float can be casted."); + + return run_cast_to_f8(x, rng); +} + +template +CK_TILE_HOST_DEVICE Y cast_from_f8(X x) +{ + // check datatype + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported."); + + return run_cast_from_f8(x); +} +} // namespace impl + +CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_sr_raw(float x) +{ + constexpr int seed = 42; + uint32_t rng = prand_generator_t{}(reinterpret_cast(&x), x); +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + float max_fp8 = 240.0f; + x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? 
-max_fp8 : x); + union + { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // not endian independent + } val; + val.fval = x; + uint32_t ival = 0; + ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; + return val.i8val[0]; // little endian +#else + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic; + return bit_cast(impl::cast_to_f8(x, rng)); +#endif +} + +CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_sr_raw(float x) +{ + constexpr int seed = 42; + uint32_t rng = prand_generator_t{}(reinterpret_cast(&x), x); +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + union + { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // not endian independent + } val; + val.fval = x; + uint32_t ival = 0; + ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; + return val.i8val[0]; // little endian +#else + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr fp8_rounding_mode rm = fp8_rounding_mode::stochastic; + return bit_cast(impl::cast_to_f8(x, rng)); +#endif +} + +CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_rtn_raw(float x) +{ +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + float max_fp8 = 240.0f; + x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); + union + { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // not endian independent + } val; + val.fval = x; + uint32_t ival = 0; + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0 + val.i32val = ival; + return val.i8val[0]; +#else + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard; + constexpr uint32_t rng = 0; + return bit_cast(impl::cast_to_f8(x, rng)); +#endif +} +CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x) +{ +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + union + { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // not endian independent + } val; + val.fval = x; + uint32_t ival = 0; + ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0 + val.i32val = ival; + return val.i8val[0]; +#else + constexpr bool negative_zero_nan = true; + constexpr bool clip = true; + constexpr fp8_rounding_mode rm = fp8_rounding_mode::standard; + constexpr uint32_t rng = 0; + return bit_cast(impl::cast_to_f8(x, rng)); +#endif +} + +// clang-format off +template +CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant) +{ + if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x); + else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x); + else return fp8_raw_t{0}; +} + +template +CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant) +{ + if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x); + else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x); + else return bf8_raw_t{0}; +} + +CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) +{ +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + float fval; + uint32_t i32val = static_cast(x); + fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); + // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + return fval; +#else + constexpr bool negative_zero_nan = 
true; + return impl::cast_from_f8(bit_cast(x)); +#endif +} + +CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x) +{ +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + float fval; + uint32_t i32val = static_cast(x); + fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); + // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); + return fval; +#else + constexpr bool negative_zero_nan = true; + return impl::cast_from_f8(bit_cast(x)); +#endif +} + +template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> +CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant = {}) +{ + return bit_cast(float_to_fp8_raw(x, constant{})); +} + +template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> +CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant = {}) +{ + return bit_cast(float_to_bf8_raw(x, constant{})); +} + +CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x) +{ + return fp8_to_float_raw(bit_cast(x)); +} + +CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x) +{ + return bf8_to_float_raw(bit_cast(x)); +} + +// clang-format on + +template +struct numeric_traits; + +template <> +struct numeric_traits +{ + static constexpr int exp = 4; + static constexpr int mant = 3; +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + static constexpr int bias = 8; +#else + static constexpr int bias = 7; +#endif +}; + +template <> +struct numeric_traits +{ + static constexpr int exp = 5; + static constexpr int mant = 2; +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + static constexpr int bias = 16; +#else + static constexpr int bias = 15; // IEEE +#endif +}; + +template +struct numeric; + +template <> +struct numeric +{ + // minimum finite value, or minimum positive normalized value for float + CK_TILE_HOST_DEVICE static constexpr fp8_t min() + { + return bit_cast(static_cast(0x08)); + } + + // minumum finite value + CK_TILE_HOST_DEVICE static constexpr fp8_t lowest() + { + return bit_cast(static_cast(0xff)); + } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr fp8_t max() + { + return bit_cast(static_cast(0x7f)); + } + + // difference between 1.0 and next value representable by float + CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon() + { + return bit_cast(static_cast(0x20)); + } + + // maximum rounding error + // bin : 7 6543 210 + // bits: s eeee mmm + // 0 0110 000 (0.5) + // + CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() + { + return bit_cast(static_cast(0x30)); + } + + // positive infinity value + CK_TILE_HOST_DEVICE static constexpr fp8_t infinity() + { + return bit_cast(static_cast(0x80)); + } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN() + { + return bit_cast(static_cast(0x80)); + } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN() + { + return bit_cast(static_cast(0x80)); + } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() + { + return bit_cast(static_cast(0x01)); + } + + CK_TILE_HOST_DEVICE static constexpr fp8_t zero() + { + return bit_cast(static_cast(0)); + } +}; + +template <> +struct numeric +{ + // minimum finite value, or minimum positive normalized value for float + CK_TILE_HOST_DEVICE static constexpr bf8_t min() + { + return bit_cast(static_cast(0x04)); + } + + // minumum finite value + CK_TILE_HOST_DEVICE static constexpr bf8_t lowest() + { + return bit_cast(static_cast(0xff)); + } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr bf8_t max() + { + return bit_cast(static_cast(0x7f)); 
+ } + + // difference between 1.0 and next value representable by float + CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon() + { + return bit_cast(static_cast(0x34)); + } + + // maximum rounding error + // bin : 7 65432 10 + // bits: s eeeee mm + // 0 01110 00 (0.5) + // + CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() + { + return bit_cast(static_cast(0x38)); + } + + // positive infinity value + CK_TILE_HOST_DEVICE static constexpr bf8_t infinity() + { + return bit_cast(static_cast(0x80)); + } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN() + { + return bit_cast(static_cast(0x80)); + } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN() + { + return bit_cast(static_cast(0x80)); + } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() + { + return bit_cast(static_cast(0x01)); + } + + CK_TILE_HOST_DEVICE static constexpr bf8_t zero() + { + return bit_cast(static_cast(0)); + } +}; + +#if CK_TILE_USE_CUSTOM_DATA_TYPE +CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t) +CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t) +#endif + +// math +CK_TILE_HOST_DEVICE +fp8_t abs(const fp8_t& x) +{ + return bit_cast(static_cast(bit_cast(x) & 0x7f)); +} + +CK_TILE_HOST_DEVICE +bool isnan(const fp8_t& x) +{ + uint8_t xx = bit_cast(x); + return xx == 0x80; // TODO: NANOO +} + +CK_TILE_DEVICE +fp8_t sqrt(fp8_t x) { return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); }; + +CK_TILE_DEVICE +fp8_t exp(fp8_t x) { return static_cast(__expf(static_cast(x))); }; + +CK_TILE_DEVICE +fp8_t exp2(fp8_t x) { return static_cast(exp2f(static_cast(x))); }; + +CK_TILE_DEVICE +fp8_t log(fp8_t x) { return static_cast(__logf(static_cast(x))); }; + +CK_TILE_HOST_DEVICE +bf8_t abs(const bf8_t& x) +{ + return bit_cast(static_cast(bit_cast(x) & 0x7f)); +} + +CK_TILE_HOST_DEVICE +bool isnan(const bf8_t& x) +{ + uint8_t xx = bit_cast(x); + return xx == 0x80; // TODO: NANOO +} + +CK_TILE_DEVICE +bf8_t sqrt(bf8_t x) { return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); }; + +CK_TILE_DEVICE +bf8_t exp(bf8_t x) { return static_cast(__expf(static_cast(x))); }; + +CK_TILE_DEVICE +bf8_t exp2(bf8_t x) { return static_cast(exp2f(static_cast(x))); }; + +CK_TILE_DEVICE +bf8_t log(bf8_t x) { return static_cast(__logf(static_cast(x))); }; + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp new file mode 100644 index 000000000..c616b6939 --- /dev/null +++ b/include/ck_tile/core/numeric/half.hpp @@ -0,0 +1,385 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
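
Before moving on to the half-precision header, it helps to pin down the e4m3 layout that the earlier comment table and the limits above describe. The sketch below decodes an IEEE-flavoured e4m3 byte (bias 7, NaN encoded as s.1111.111, no infinity) on the host. It is an illustrative reference only, not the gfx94x hardware conversion path and not the NANOO variant.

// Standalone host-side reference; function name is illustrative.
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode an 8-bit e4m3 value: 1 sign, 4 exponent, 3 mantissa bits, bias 7.
float e4m3_to_float(uint8_t v)
{
    const int sign = v >> 7;
    const int exp  = (v >> 3) & 0xF;
    const int mant = v & 0x7;

    if(exp == 0xF && mant == 0x7) // s.1111.111 -> NaN
        return NAN;

    float mag;
    if(exp == 0) // subnormal: mant * 2^-9
        mag = std::ldexp(static_cast<float>(mant), -9);
    else // normal: (1 + mant/8) * 2^(exp - 7)
        mag = std::ldexp(1.0f + mant / 8.0f, exp - 7);

    return sign ? -mag : mag;
}

int main()
{
    std::printf("%g\n", e4m3_to_float(0x7E)); // 0.1111.110 -> 448, largest finite
    std::printf("%g\n", e4m3_to_float(0x08)); // 0.0001.000 -> 2^-6 = 0.015625
    std::printf("%g\n", e4m3_to_float(0x01)); // 0.0000.001 -> 2^-9 = 0.001953125
}
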
+ +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/numeric/numeric.hpp" +#include + +#pragma once + +namespace ck_tile { + +using fp16_hip_t = _Float16; // most of hip internal function use this type +using fp16_raw_t = uint16_t; + +CK_TILE_HOST_DEVICE +constexpr float fp16_to_float_hip(const fp16_hip_t& x); + +CK_TILE_HOST_DEVICE +constexpr double fp16_to_double_hip(const fp16_hip_t& x); + +CK_TILE_HOST_DEVICE +constexpr fp16_hip_t float_to_fp16_hip(const float& x); + +CK_TILE_HOST_DEVICE +constexpr fp16_hip_t double_to_fp16_hip(const double& x); + +#if CK_TILE_USE_CUSTOM_DATA_TYPE +// HIP use fp16_hip_t as interchangable data type for float16 +struct alignas(2) half_t +{ + using raw_type = fp16_raw_t; + raw_type data; + + CK_TILE_HOST_DEVICE + static constexpr half_t bit_cast(raw_type x) + { + half_t y; + y.data = x; + return y; + } + + CK_TILE_HOST_DEVICE + constexpr fp16_hip_t to_fp16() const { return ck_tile::bit_cast(data); } + + // constructor + constexpr half_t() : data{} {} + + // construct from HIP half + CK_TILE_HOST_DEVICE + explicit constexpr half_t(const fp16_hip_t& x) : data(ck_tile::bit_cast(x)) {} + + // construct from float + CK_TILE_HOST_DEVICE + explicit constexpr half_t(const float& x) : half_t(float_to_fp16_hip(x)) {} + + // construct from double + CK_TILE_HOST_DEVICE + explicit constexpr half_t(const double& x) : half_t(double_to_fp16_hip(x)) {} + + // construct from int + CK_TILE_HOST_DEVICE + explicit constexpr half_t(const int& x) : half_t(static_cast(__int2half_rn(x))) {} + + // construct from unsigned int + CK_TILE_HOST_DEVICE + explicit constexpr half_t(const unsigned int& x) + : half_t(static_cast(__uint2half_rn(x))) + { + } + + // cast to float + CK_TILE_HOST_DEVICE + explicit constexpr operator float() const { return fp16_to_float_hip(to_fp16()); } + + // cast to double + CK_TILE_HOST_DEVICE + explicit constexpr operator double() const { return fp16_to_double_hip(to_fp16()); } + + // cast to int + CK_TILE_HOST_DEVICE + explicit constexpr operator int() const + { + return static_cast(fp16_to_float_hip(to_fp16())); + } + + CK_TILE_HOST_DEVICE + explicit constexpr operator fp16_hip_t() const { return ck_tile::bit_cast(data); } + + // internal access + CK_TILE_HOST_DEVICE + constexpr raw_type& get() { return data; } + + CK_TILE_HOST_DEVICE + constexpr raw_type get() const { return data; } +}; + +template +struct native_t; + +template <> +struct native_t +{ + using type = _Float16; +}; + +using fp16_t = half_t; +using fp16_raw_t = typename half_t::raw_type; +#else +using fp16_t = _Float16; +using half_t = _Float16; +using fp16_raw_t = ushort; +#endif + +// conversions +CK_TILE_HOST_DEVICE +constexpr float fp16_to_float_hip(const fp16_hip_t& x) +{ + // return __half2float(x); + return static_cast(x); +} + +CK_TILE_HOST_DEVICE +constexpr double fp16_to_double_hip(const fp16_hip_t& x) +{ + return static_cast(fp16_to_float_hip(x)); +} + +CK_TILE_HOST_DEVICE +constexpr fp16_hip_t float_to_fp16_hip(const float& x) +{ + return __float2half(x); + // return static_cast(x); +} + +CK_TILE_HOST_DEVICE +constexpr fp16_hip_t double_to_fp16_hip(const double& x) +{ + // return __float2half(x); + return static_cast(x); +} + +CK_TILE_HOST_DEVICE +constexpr float fp16_to_float(const half_t& x) { return static_cast(x); } + +CK_TILE_HOST_DEVICE +constexpr float fp16_to_double(const half_t& x) { return static_cast(x); } + +CK_TILE_HOST_DEVICE +constexpr half_t float_to_fp16(const float& x) { return static_cast(x); } + 
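
The wrapper above boils down to a raw 16-bit payload, explicit conversions through float, and bit-level access to the storage. A minimal standalone sketch of that pattern follows; my_half and from_float are made-up names, and the code assumes a compiler that provides _Float16 (the same native type the header falls back to when the custom data type is disabled).

// Illustrative only, not ck_tile::half_t.
#include <cstdint>
#include <cstdio>
#include <cstring>

struct my_half
{
    uint16_t data; // raw binary16 payload

    static my_half from_float(float f)
    {
        _Float16 h = static_cast<_Float16>(f);
        my_half r;
        std::memcpy(&r.data, &h, sizeof(r.data)); // bit_cast equivalent
        return r;
    }

    explicit operator float() const
    {
        _Float16 h;
        std::memcpy(&h, &data, sizeof(h));
        return static_cast<float>(h);
    }
};

int main()
{
    my_half one = my_half::from_float(1.0f);
    std::printf("raw = 0x%04X, value = %g\n",
                static_cast<unsigned>(one.data),
                static_cast<float>(one)); // prints raw = 0x3C00, value = 1
}
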
+CK_TILE_HOST_DEVICE +constexpr half_t double_to_fp16(const double& x) { return static_cast(x); } + +// limits +template +struct numeric; + +template <> +struct numeric +{ + // minimum finite value, or minimum positive normalized value for float + CK_TILE_HOST_DEVICE static constexpr half_t min() + { + return bit_cast(static_cast(0x0400)); + } + + // minumum finite value + CK_TILE_HOST_DEVICE static constexpr half_t lowest() + { + return bit_cast(static_cast(0xFBFF)); + } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr half_t max() + { + return bit_cast(static_cast(0x7BFF)); + } + + // difference between 1.0 and next value representable by float + CK_TILE_HOST_DEVICE static constexpr half_t epsilon() + { + return bit_cast(static_cast(0x1800)); + } + + // maximum rounding error + // bin : f edcba 9876543210 + // bits: s eeeee mmmmmmmmmm + // 0 01110 0000000000 (0.5) + // + CK_TILE_HOST_DEVICE static constexpr half_t round_error() + { + return bit_cast(static_cast(0x3800)); + } + + // positive infinity value + CK_TILE_HOST_DEVICE static constexpr half_t infinity() + { + return bit_cast(static_cast(0x7C00)); + } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN() + { + return bit_cast(static_cast(0x7FFF)); + } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN() + { + return bit_cast(static_cast(0x7FFF)); + } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr half_t denorm_min() + { + return bit_cast(static_cast(0x0001)); + } + + CK_TILE_HOST_DEVICE static constexpr half_t zero() + { + return bit_cast(static_cast(0)); + } +}; + +template +struct numeric_traits; + +template <> +struct numeric_traits +{ + static constexpr int exp = 5; + static constexpr int mant = 10; + static constexpr int bias = 15; + static constexpr uint16_t nan_mask = 0x7C00; + static constexpr uint16_t head_mask = 0xFC00; + static constexpr uint16_t mant_mask = 0x3FF; + static constexpr uint16_t exp_mask = 0x1F; + static constexpr uint32_t Inf = 0x7C00; + static constexpr uint32_t NegInf = 0xFC00; + static constexpr uint32_t NaN = 0x7C01; + static constexpr uint32_t Neg0 = 0x8000; + using bitwise_type = uint16_t; +}; + +#if CK_TILE_USE_CUSTOM_DATA_TYPE +// arithmetic +CK_TILE_DEVICE bool operator==(const half_t& x, const half_t& y) +{ + return __heq(x.to_fp16(), y.to_fp16()); +} + +CK_TILE_DEVICE +bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); } + +CK_TILE_DEVICE +bool operator<(const half_t& x, const half_t& y) { return __hlt(x.to_fp16(), y.to_fp16()); } + +CK_TILE_DEVICE +bool operator<=(const half_t& x, const half_t& y) { return __hle(x.to_fp16(), y.to_fp16()); } + +CK_TILE_DEVICE +bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.to_fp16()); } + +CK_TILE_DEVICE +bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); } + +#if 0 +CK_TILE_DEVICE +half_t operator+(const half_t& x, const half_t& y) +{ + return half_t(__hadd(x.to_fp16(), y.to_fp16())); +} + +CK_TILE_DEVICE +half_t operator-(const half_t& x) { return half_t(__hneg(x.to_fp16())); } + +CK_TILE_DEVICE +half_t operator-(const half_t& x, const half_t& y) +{ + return half_t(__hsub(x.to_fp16(), y.to_fp16())); +} + +CK_TILE_DEVICE +half_t operator*(const half_t& x, const half_t& y) +{ + return half_t(__hmul(x.to_fp16(), y.to_fp16())); +} + +CK_TILE_DEVICE +half_t operator/(const half_t& x, const half_t& y) +{ + return half_t(__hdiv(x.to_fp16(), 
y.to_fp16())); +} + +CK_TILE_DEVICE +half_t& operator+=(half_t& x, const half_t& y) +{ + x = half_t(__hadd(x.to_fp16(), y.to_fp16())); + return x; +} + +CK_TILE_DEVICE +half_t& operator-=(half_t& x, const half_t& y) +{ + x = half_t(__hsub(x.to_fp16(), y.to_fp16())); + return x; +} + +CK_TILE_DEVICE +half_t& operator*=(half_t& x, const half_t& y) +{ + x = half_t(__hmul(x.to_fp16(), y.to_fp16())); + return x; +} + +CK_TILE_DEVICE +half_t& operator/=(half_t& x, const half_t& y) +{ + x = half_t(__hdiv(x.to_fp16(), y.to_fp16())); + return x; +} + +CK_TILE_DEVICE +half_t& operator++(half_t& x) +{ + x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16())); + return x; +} + +CK_TILE_DEVICE +half_t& operator--(half_t& x) +{ + x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16())); + return x; +} + +CK_TILE_DEVICE +half_t operator++(half_t& x, int) +{ + half_t y(x); + x = half_t(__hadd(x.to_fp16(), half_t(1.0f).to_fp16())); + return y; +} + +CK_TILE_DEVICE +half_t operator--(half_t& x, int) +{ + half_t y(x); + x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16())); + return y; +} +#endif + +#if CK_TILE_USE_CUSTOM_DATA_TYPE +CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t) +#endif + +// math +CK_TILE_HOST_DEVICE +half_t abs(const half_t& x) { return bit_cast(x.get() & 0x7fff); } + +CK_TILE_HOST_DEVICE +bool isnan(const half_t& x) +{ + uint16_t xx = x.get(); + return (xx & 0x7FFF) > 0x7C00; +} + +CK_TILE_DEVICE +half_t sqrt(half_t x) +{ + return static_cast(__builtin_amdgcn_sqrtf(static_cast(x))); +}; + +CK_TILE_DEVICE +half_t exp(half_t x) { return static_cast(__expf(static_cast(x))); }; + +CK_TILE_DEVICE +half_t exp2(half_t x) { return static_cast(exp2f(static_cast(x))); }; + +CK_TILE_DEVICE +half_t log(half_t x) { return static_cast(__logf(static_cast(x))); }; +#endif +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/integer.hpp b/include/ck_tile/core/numeric/integer.hpp new file mode 100644 index 000000000..3faf3020a --- /dev/null +++ b/include/ck_tile/core/numeric/integer.hpp @@ -0,0 +1,13 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include + +namespace ck_tile { + +using index_t = int32_t; +using long_index_t = int64_t; +using int8_t = int8_t; + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/integral_constant.hpp b/include/ck_tile/core/numeric/integral_constant.hpp new file mode 100644 index 000000000..ea7a67abc --- /dev/null +++ b/include/ck_tile/core/numeric/integral_constant.hpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" + +namespace ck_tile { + +template +struct constant +{ + using value_type = decltype(v); + using type = constant; // using injected-class-name + static constexpr value_type value = v; + CK_TILE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } + CK_TILE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } + CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; } +}; + +template +struct integral_constant : constant +{ + using value_type = T; + using type = integral_constant; // using injected-class-name + static constexpr T value = v; + // constexpr CK_TILE_HOST_DEVICE operator value_type() const noexcept { return value; } + // constexpr CK_TILE_HOST_DEVICE value_type operator()() const noexcept { return value; } // +}; + +template +using number = constant; + +template +using long_number = constant; + +template +using bool_constant = constant; + +#define CK_TILE_LEFT_UNARY_OP(OP) \ + template \ + CK_TILE_HOST_DEVICE constexpr auto operator OP(constant) \ + { \ + return constant<(OP x)>{}; \ + } + +#define CK_TILE_BINARY_OP(OP) \ + template \ + CK_TILE_HOST_DEVICE constexpr auto operator OP(constant, constant) \ + { \ + return constant<(x OP y)>{}; \ + } + +CK_TILE_LEFT_UNARY_OP(+) +CK_TILE_LEFT_UNARY_OP(-) +CK_TILE_LEFT_UNARY_OP(~) +CK_TILE_LEFT_UNARY_OP(!) +CK_TILE_LEFT_UNARY_OP(*) + +CK_TILE_BINARY_OP(+) +CK_TILE_BINARY_OP(-) +CK_TILE_BINARY_OP(*) +CK_TILE_BINARY_OP(/) +CK_TILE_BINARY_OP(%) +CK_TILE_BINARY_OP(&) +CK_TILE_BINARY_OP(|) +CK_TILE_BINARY_OP(^) +CK_TILE_BINARY_OP(<<) +CK_TILE_BINARY_OP(>>) +CK_TILE_BINARY_OP(&&) +CK_TILE_BINARY_OP(||) +CK_TILE_BINARY_OP(==) +CK_TILE_BINARY_OP(!=) +CK_TILE_BINARY_OP(>) +CK_TILE_BINARY_OP(<) +CK_TILE_BINARY_OP(>=) +CK_TILE_BINARY_OP(<=) + +#undef CK_TILE_LEFT_UNARY_OP +#undef CK_TILE_BINARY_OP + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp new file mode 100644 index 000000000..72ec607b4 --- /dev/null +++ b/include/ck_tile/core/numeric/math.hpp @@ -0,0 +1,539 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
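
Before the math helpers, a quick standalone illustration of what the constant<>/number<> machinery just defined buys: the value lives in the type, so arithmetic on these objects is resolved entirely at compile time and the result is again a typed constant. Only operator+ is sketched here, and the names mirror, but are not, the CK definitions.

// Standalone sketch, C++17 (auto non-type template parameter).
#include <type_traits>

template <auto v>
struct constant
{
    using value_type = decltype(v);
    static constexpr value_type value = v;
    constexpr operator value_type() const noexcept { return value; }
};

template <auto x, auto y>
constexpr auto operator+(constant<x>, constant<y>)
{
    return constant<x + y>{}; // result is again a compile-time constant
}

template <int N>
using number = constant<N>;

int main()
{
    constexpr auto seven = number<3>{} + number<4>{};
    static_assert(std::is_same_v<decltype(seven), const constant<7>>);
    static_assert(seven == 7); // implicit conversion to the underlying value
    return 0;
}
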
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include +#include +#include + +namespace ck_tile { + +template +struct scales_c +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const -> decltype(lhs * rhs) + { + return lhs * rhs; + } +}; + +template +struct scales +{ + static_assert(std::is_copy_constructible_v); + + CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {} + + template + CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const + -> decltype(std::declval() * rhs) + { + return lhs_ * rhs; + } + + private: + Scale lhs_; +}; + +/// FIXME: create macro to replace '__host__ __device__' and nothing more +template +__host__ __device__ scales(Scale)->scales; + +template +struct plus +{ + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs + rhs) + { + return lhs + rhs; + } +}; + +template <> +struct plus +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs + rhs) + { + return lhs + rhs; + } +}; + +/// FIXME: create macro to replace '__host__ __device__' and nothing more +__host__ __device__ plus()->plus; + +template +struct minus +{ + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs - rhs) + { + return lhs - rhs; + } +}; + +template <> +struct minus +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs - rhs) + { + return lhs - rhs; + } +}; + +/// FIXME: create macro to replace '__host__ __device__' and nothing more +__host__ __device__ minus()->minus; + +template +struct multiplies +{ + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs * rhs) + { + return lhs * rhs; + } +}; + +template <> +struct multiplies +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs * rhs) + { + return lhs * rhs; + } +}; + +/// FIXME: create macro to replace '__host__ __device__' and nothing more +__host__ __device__ multiplies()->multiplies; + +template +struct maximize +{ + CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a >= b ? a : b; } +}; + +template +struct minimize +{ + CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const { return a <= b ? a : b; } +}; + +template +struct integer_divide_ceiler +{ + CK_TILE_HOST_DEVICE constexpr T operator()(T a, T b) const + { + static_assert(std::is_same{} || std::is_same{}, "wrong type"); + return (a + b - number<1>{}) / b; + } +}; + +template +CK_TILE_HOST_DEVICE constexpr auto integer_divide_floor(X x, Y y) +{ + return x / y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y) +{ + return (x + y - number<1>{}) / y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y) +{ + return y * integer_divide_ceil(x, y); +} + +template +CK_TILE_HOST_DEVICE constexpr T max(T x) +{ + return x; +} + +template +CK_TILE_HOST constexpr T max(T x, T y) +{ + return x > y ? x : y; +} + +template +CK_TILE_DEVICE constexpr T max(T x, T y) +{ + return x > y ? 
x : y; +} + +template <> +CK_TILE_DEVICE constexpr float max(float x, float y) +{ + return __builtin_fmaxf(x, y); // can resultin v_max3_f32 +} + +template <> +CK_TILE_DEVICE constexpr double max(double x, double y) +{ + return __builtin_fmax(x, y); // maybe still v_max3_f32 +} + +template +CK_TILE_HOST_DEVICE constexpr index_t max(number, index_t y) +{ + return X > y ? X : y; +} + +template +CK_TILE_HOST_DEVICE constexpr index_t max(index_t x, number) +{ + return x > Y ? x : Y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto max(X x, Ys... ys) +{ + static_assert(sizeof...(Ys) > 0, "not enough argument"); + return max(x, max(ys...)); +} + +template +CK_TILE_HOST_DEVICE constexpr T min(T x) +{ + return x; +} + +template +CK_TILE_HOST constexpr T min(T x, T y) +{ + return x < y ? x : y; +} + +template +CK_TILE_DEVICE constexpr T min(T x, T y) +{ + return x < y ? x : y; +} + +template <> +CK_TILE_DEVICE constexpr float min(float x, float y) +{ + return __builtin_fminf(x, y); +} + +template <> +CK_TILE_DEVICE constexpr double min(double x, double y) +{ + return __builtin_fmin(x, y); +} + +template +CK_TILE_HOST_DEVICE constexpr index_t min(number, index_t y) +{ + return X < y ? X : y; +} + +template +CK_TILE_HOST_DEVICE constexpr index_t min(index_t x, number) +{ + return x < Y ? x : Y; +} + +template +CK_TILE_HOST_DEVICE constexpr auto min(X x, Ys... ys) +{ + static_assert(sizeof...(Ys) > 0, "not enough argument"); + return min(x, min(ys...)); +} + +template +CK_TILE_HOST_DEVICE constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound) +{ + return min(max(x, lowerbound), upperbound); +} + +CK_TILE_HOST int clz(uint32_t x) { return __builtin_clz(x); } +CK_TILE_DEVICE int clz(uint32_t x) { return __clz(x); } + +// greatest common divisor, aka highest common factor +CK_TILE_HOST_DEVICE constexpr index_t gcd(index_t x, index_t y) +{ + if(x < 0) + { + return gcd(-x, y); + } + else if(y < 0) + { + return gcd(x, -y); + } + else if(x == y || x == 0) + { + return y; + } + else if(y == 0) + { + return x; + } + else if(x > y) + { + return gcd(x % y, y); + } + else + { + return gcd(x, y % x); + } +} + +template +CK_TILE_HOST_DEVICE constexpr auto gcd(number, number) +{ + constexpr auto r = gcd(X, Y); + + return number{}; +} + +template = 2, bool>::type = false> +CK_TILE_HOST_DEVICE constexpr auto gcd(X x, Ys... ys) +{ + return gcd(x, gcd(ys...)); +} + +// least common multiple +template +CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Y y) +{ + return (x * y) / gcd(x, y); +} + +template = 2, bool>::type = false> +CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... 
ys) +{ + return lcm(x, lcm(ys...)); +} + +template +struct equal +{ + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs == rhs) + { + return lhs == rhs; + } +}; + +template <> +struct equal +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs == rhs) + { + return lhs == rhs; + } +}; + +/// FIXME: create macro to replace '__host__ __device__' and nothing more +__host__ __device__ equal()->equal; + +template <> +struct equal +{ + CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const + { + return bit_cast(lhs) == bit_cast(rhs); + } +}; + +template <> +struct equal +{ + CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const + { + return bit_cast(lhs) == bit_cast(rhs); + } +}; + +template +struct less +{ + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs < rhs) + { + return lhs < rhs; + } +}; + +template <> +struct less +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs < rhs) + { + return lhs < rhs; + } +}; + +/// FIXME: create macro to replace '__host__ __device__' and nothing more +__host__ __device__ less()->less; + +template +struct less_equal +{ + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs <= rhs) + { + return lhs <= rhs; + } +}; + +template <> +struct less_equal +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const + -> decltype(lhs <= rhs) + { + return lhs <= rhs; + } +}; + +/// FIXME: create macro to replace '__host__ __device__' and nothing more +__host__ __device__ less_equal()->less_equal; + +template <> +struct less_equal +{ + CK_TILE_HOST_DEVICE constexpr bool operator()(float lhs, float rhs) const + { + return lhs < rhs || bit_cast(lhs) == bit_cast(rhs); + } +}; + +template <> +struct less_equal +{ + CK_TILE_HOST_DEVICE constexpr bool operator()(double lhs, double rhs) const + { + return lhs < rhs || bit_cast(lhs) == bit_cast(rhs); + } +}; + +CK_TILE_HOST_DEVICE constexpr int32_t next_power_of_two(int32_t x) +{ + // TODO: x need to be 2 ~ 0x7fffffff. 
0, 1, or larger than 0x7fffffff will compile fail + return 1 << (32 - clz(x - 1)); +} + +template +CK_TILE_HOST_DEVICE constexpr auto next_power_of_two() +{ + constexpr index_t y = next_power_of_two(X); + return number{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto next_power_of_two(number) +{ + constexpr index_t y = next_power_of_two(X); + return number{}; +} + +CK_TILE_HOST_DEVICE constexpr int32_t integer_log2_floor(int32_t x) +{ + // TODO: x need to be 1 ~ 0x7fffffff + // __builtin_clz will produce unexpected result if x is 0; + return 31 - __builtin_clz(x); +} + +CK_TILE_HOST_DEVICE constexpr bool is_power_of_two_integer(int32_t x) +{ + // TODO: x need to be 1 ~ 0x7fffffff + return x == (1 << integer_log2_floor(x)); +} + +#ifndef C_LOG2E +#define C_LOG2E 1.44269504088896340736 // log2(e) +#endif + +template +struct log2e; + +template <> +struct log2e +{ + static constexpr double value = C_LOG2E; +}; + +template <> +struct log2e +{ + static constexpr float value = C_LOG2E; +}; + +template +constexpr T log2e_v = log2e::value; + +// math +CK_TILE_HOST_DEVICE +float abs(const float& x) +{ + union + { + float f32; + uint32_t u32; + } y; + y.f32 = x; + y.u32 = y.u32 & 0x7fffffff; + return y.f32; +} + +CK_TILE_HOST_DEVICE +bool isnan(const float& x) +{ + uint32_t xx = bit_cast(x); + return (xx & 0x7fffffff) > 0x7F800000; +} + +CK_TILE_HOST float sqrt(float x) { return std::sqrt(x); }; + +CK_TILE_HOST double sqrt(double x) { return std::sqrt(x); }; + +CK_TILE_DEVICE +float sqrt(float x) { return __builtin_amdgcn_sqrtf(x); }; + +CK_TILE_DEVICE +double sqrt(double x) { return __builtin_amdgcn_sqrt(x); }; + +CK_TILE_DEVICE +float exp(float x) { return __expf(x); }; + +CK_TILE_HOST +float exp(float x) { return std::expf(x); } + +CK_TILE_DEVICE +float exp2(float x) { return exp2f(x); }; + +CK_TILE_HOST +float exp2(float x) { return std::exp2f(x); }; + +CK_TILE_DEVICE +float log(float x) { return __logf(x); }; + +CK_TILE_HOST +float log(float x) { return std::logf(x); }; + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/numeric.hpp b/include/ck_tile/core/numeric/numeric.hpp new file mode 100644 index 000000000..35745b12d --- /dev/null +++ b/include/ck_tile/core/numeric/numeric.hpp @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include +#include + +namespace ck_tile { + +// this struct has the information of +// 1. limit of a certain type, simliar to std::numeric_limits +// 2. some pre-defined value, zero, one... 
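// ---------------------------------------------------------------------------
// [Editor's note] Host-side sketch for the integer helpers in math.hpp above,
// not part of the patch and assuming the usual HIP build of these headers.
// The values follow directly from the definitions shown: integer_divide_ceil
// rounds up, gcd/lcm are the usual helpers, and next_power_of_two(x) computes
// 1 << (32 - clz(x - 1)) for x in [2, 2^30]. next_power_of_two is checked at
// run time here because clz() is not declared constexpr.
#include "ck_tile/core/numeric/math.hpp"
#include <cassert>

inline void math_usage_sketch()
{
    using namespace ck_tile;
    static_assert(integer_divide_ceil(10, 4) == 3); // (10 + 4 - 1) / 4
    static_assert(integer_least_multiple(10, 4) == 12);
    static_assert(gcd(12, 18) == 6);
    static_assert(lcm(4, 6) == 12);
    static_assert(is_power_of_two_integer(64) && !is_power_of_two_integer(48));
    assert(next_power_of_two(5) == 8);
}
// ---------------------------------------------------------------------------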
+// +template +struct numeric +{ + // minimum finite value, or minimum positive normalized value for float + CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits::min(); } + + // minumum finite value + CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits::lowest(); } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits::max(); } + + // difference between 1.0 and next value representable by float + CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits::epsilon(); } + + // maximum rounding error + CK_TILE_HOST_DEVICE static constexpr T round_error() + { + return std::numeric_limits::round_error(); + } + + // positive infinity value + CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits::infinity(); } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr T quiet_NaN() + { + return std::numeric_limits::quiet_NaN(); + } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr T signaling_NaN() + { + return std::numeric_limits::signaling_NaN(); + } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr T denorm_min() + { + return std::numeric_limits::denorm_min(); + } + + CK_TILE_HOST_DEVICE static constexpr T zero() { return static_cast(0); } + + CK_TILE_HOST_DEVICE static constexpr T one() { return static_cast(1); } + +#ifndef C_LOG2E +#define C_LOG2E 1.44269504088896340736 // log2(e) +#endif + + CK_TILE_HOST_DEVICE static constexpr T log2e() + { + if constexpr(std::is_same_v || std::is_same_v) + { + return static_cast(C_LOG2E); + } + else + { + return 0; // TODO: integer? + } + } +}; + +template +struct numeric_traits; + +template <> +struct numeric_traits +{ + static constexpr int exp = 8; + static constexpr int mant = 23; + static constexpr int bias = 127; + static constexpr uint32_t nan_mask = 0x7F800000; + static constexpr uint32_t head_mask = 0xFF800000; + static constexpr uint32_t mant_mask = 0x7FFFFF; + static constexpr uint32_t exp_mask = 0xFF; + static constexpr uint32_t Inf = 0x7F800000; + static constexpr uint32_t NegInf = 0xFF800000; + static constexpr uint32_t NaN = 0x7F800001; + static constexpr uint32_t Neg0 = 0x80000000; + using bitwise_type = uint32_t; +}; + +} // namespace ck_tile + +#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \ + attr_ bool operator==(const type_& x, const type_& y) \ + { \ + return static_cast(x) == static_cast(y); \ + } \ + attr_ bool operator!=(const type_& x, const type_& y) \ + { \ + return static_cast(x) != static_cast(y); \ + } \ + attr_ bool operator<(const type_& x, const type_& y) \ + { \ + return static_cast(x) < static_cast(y); \ + } \ + attr_ bool operator<=(const type_& x, const type_& y) \ + { \ + return static_cast(x) <= static_cast(y); \ + } \ + attr_ bool operator>(const type_& x, const type_& y) \ + { \ + return static_cast(x) > static_cast(y); \ + } \ + attr_ bool operator>=(const type_& x, const type_& y) \ + { \ + return static_cast(x) >= static_cast(y); \ + } \ + attr_ type_ operator+(const type_& x, const type_& y) \ + { \ + return type_(static_cast(x) + static_cast(y)); \ + } \ + attr_ type_ operator-(const type_& x) \ + { \ + constexpr uint32_t bits = sizeof(type_) * 8; \ + constexpr uint32_t mask = 1 << (bits - 1); \ + type_ y = x; \ + y.data ^= static_cast(mask); \ + return y; \ + } \ + attr_ type_ operator-(const type_& x, const type_& y) \ + { \ + return type_(static_cast(x) - static_cast(y)); \ + } \ + attr_ type_ operator*(const type_& x, const 
type_& y) \ + { \ + return type_(static_cast(x) * static_cast(y)); \ + } \ + attr_ type_ operator/(const type_& x, const type_& y) \ + { \ + return type_(static_cast(x) / static_cast(y)); \ + } \ + attr_ type_& operator+=(type_& x, const type_& y) \ + { \ + x = type_(static_cast(x) + static_cast(y)); \ + return x; \ + } \ + attr_ type_& operator-=(type_& x, const type_& y) \ + { \ + x = type_(static_cast(x) - static_cast(y)); \ + return x; \ + } \ + attr_ type_& operator*=(type_& x, const type_& y) \ + { \ + x = type_(static_cast(x) * static_cast(y)); \ + return x; \ + } \ + attr_ type_& operator/=(type_& x, const type_& y) \ + { \ + x = type_(static_cast(x) / static_cast(y)); \ + return x; \ + } \ + attr_ type_& operator++(type_& x) \ + { \ + x = type_(static_cast(x) + 1.f); \ + return x; \ + } \ + attr_ type_& operator--(type_& x) \ + { \ + x = type_(static_cast(x) - 1.f); \ + return x; \ + } \ + attr_ type_ operator++(type_& x, int) \ + { \ + type_ y(x); \ + x = type_(static_cast(x) + 1.f); \ + return y; \ + } \ + attr_ type_ operator--(type_& x, int) \ + { \ + type_ y(x); \ + x = type_(static_cast(x) - 1.f); \ + return y; \ + } diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp new file mode 100644 index 000000000..cb18cde70 --- /dev/null +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/float8.hpp" + +namespace ck_tile { + +#if CK_TILE_USE_CUSTOM_DATA_TYPE +template +CK_TILE_HOST_DEVICE constexpr remove_cvref_t type_convert(const X& x) +{ + return static_cast(x); +} +#else +// Convert X to Y, both X and Y are non-const data types. +template || std::is_const_v), bool> = false> +CK_TILE_HOST_DEVICE constexpr Y type_convert(X x) +{ + static_assert(!std::is_reference_v && !std::is_reference_v); + return static_cast(x); +} + +// Convert X to Y, either X or Y is a const data type. +template || std::is_const_v, bool> = false> +CK_TILE_HOST_DEVICE constexpr Y type_convert(X x) +{ + static_assert(!std::is_reference_v && !std::is_reference_v); + + using non_const_y = std::remove_const_t; + using non_const_x = std::remove_const_t; + return static_cast(type_convert(x)); +} + +#define CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \ + template <> \ + CK_TILE_HOST_DEVICE constexpr dtype_ type_convert(stype_ x) \ + { \ + return sname_##_to_##dname_(x); \ + } + +CK_TILE_TYPE_CONVERT(float, float, fp16_t, fp16) +CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16) +CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8) +CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8) + +CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float) +CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float) +CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float) +CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float) + +#undef CK_TILE_TYPE_CONVERT +#endif + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp new file mode 100644 index 000000000..85d9be1c9 --- /dev/null +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
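// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch, not part of the patch. It shows the
// contract behind CK_TILE_ARITHMETIC_USING_FLOAT above with a hypothetical
// storage-only type: the macro only needs an explicit round-trip conversion
// to/from float plus a raw integer member named `data` (whose top bit is the
// sign, used by unary minus), and it then derives ==, <, +, -, *, /, +=, ++
// and friends by computing in float and converting back. The real CK types
// pass CK_TILE_HOST_DEVICE as the attribute; plain `inline` keeps this toy
// host-only, and the "encoding" below is a bare truncation, purely for
// illustration.
#include <cstdint>
#include "ck_tile/core/numeric/numeric.hpp"

struct toy_fp16_like // hypothetical type, not part of CK
{
    uint16_t data = 0;
    constexpr toy_fp16_like() = default;
    explicit toy_fp16_like(float f) : data(static_cast<uint16_t>(f)) {} // a real type would encode properly
    explicit operator float() const { return static_cast<float>(data); } // ...and decode properly
};

CK_TILE_ARITHMETIC_USING_FLOAT(inline, toy_fp16_like)

// toy_fp16_like a{2.f}, b{3.f};
// auto c = a * b; // evaluated as float(2) * float(3), then converted back to toy_fp16_like
// ---------------------------------------------------------------------------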
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// this structure is used to pick up the type inside +// using xxx = __attribute__((ext_vector_type(N))); +// because clang only allow native type + bool in this term (custom type will fail) +// overload this structure to let proper type + +template +struct native_t +{ + using type = remove_cvref_t; +}; + +// we name this as ext_vector purposely, because clang ext_vector_type extention only accept literay +// basic type to construct a ext_vector_type you must be very careful using this, or will have lot +// of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will +// have compiler error +namespace impl { +template +struct ext_vector +{ + static constexpr index_t N = N_; + using value_type = typename native_t>::type; + static_assert(!std::is_class_v); + using type = value_type __attribute__((ext_vector_type(N))); // this is danguous +}; + +template +struct ext_vector +{ + static constexpr index_t N = Vs_ * N_; + using value_type = typename native_t>::type; + static_assert(!std::is_class_v); + using type = value_type __attribute__((ext_vector_type(N))); // this is danguous +}; + +} // namespace impl + +template +using ext_vector_t = typename impl::ext_vector::type; + +// by default, any type will result in a vector_size=1 with scalar_type=T traits. +// ... unless we have other vector_traits specialization +template +struct vector_traits +{ + using scalar_type = remove_cvref_t; + static constexpr index_t vector_size = 1; +}; + +// specialization for ext_vector_type() +template +struct vector_traits +{ + using scalar_type = T; + static constexpr index_t vector_size = N; +}; + +template +using has_same_scalar_type = std::is_same>::scalar_type, + typename vector_traits>::scalar_type>; + +// below are some pre-defines of ext_vector_type +// attention! 2 vector type could be just the same type +// fp64 +using fp64_t = double; +using fp64x2_t = double __attribute__((ext_vector_type(2))); +using fp64x4_t = double __attribute__((ext_vector_type(4))); + +// fp32 +using fp32_t = float; +using fp32x2_t = float __attribute__((ext_vector_type(2))); +using fp32x4_t = float __attribute__((ext_vector_type(4))); +using fp32x8_t = float __attribute__((ext_vector_type(8))); +using fp32x16_t = float __attribute__((ext_vector_type(16))); +using fp32x32_t = float __attribute__((ext_vector_type(32))); +using fp32x64_t = float __attribute__((ext_vector_type(64))); + +// fp16 +// using fp16_t = ... +using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); +using fp16x4_t = _Float16 __attribute__((ext_vector_type(4))); +using fp16x8_t = _Float16 __attribute__((ext_vector_type(8))); +using fp16x16_t = _Float16 __attribute__((ext_vector_type(16))); +using fp16x32_t = _Float16 __attribute__((ext_vector_type(32))); +using fp16x64_t = _Float16 __attribute__((ext_vector_type(64))); + +// bf16 +// using bf16_t = ... 
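// ---------------------------------------------------------------------------
// [Editor's note] Illustrative sketch, not part of the patch. vector_traits
// above reports the scalar type and lane count of a type: plain scalars map
// to vector_size == 1, the ext_vector_type aliases report their width, and
// ext_vector_t<T, N> is expected to coincide with the pre-defined aliases.
#include <type_traits>
#include "ck_tile/core/numeric/vector_type.hpp"

inline void vector_traits_usage_sketch()
{
    using namespace ck_tile;
    static_assert(vector_traits<float>::vector_size == 1);
    static_assert(vector_traits<fp32x4_t>::vector_size == 4);
    static_assert(std::is_same_v<vector_traits<fp32x4_t>::scalar_type, float>);
    static_assert(std::is_same_v<ext_vector_t<float, 8>, fp32x8_t>);
    static_assert(has_same_scalar_type<fp32x2_t, fp32x16_t>::value);
}
// ---------------------------------------------------------------------------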
+using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); +using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4))); +using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8))); +using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16))); +using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32))); +using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64))); + +// i32 +// using int32_t = ... +using int32x2_t = int32_t __attribute__((ext_vector_type(2))); +using int32x4_t = int32_t __attribute__((ext_vector_type(4))); +using int32x8_t = int32_t __attribute__((ext_vector_type(8))); +using int32x16_t = int32_t __attribute__((ext_vector_type(16))); +using int32x32_t = int32_t __attribute__((ext_vector_type(32))); +using int32x64_t = int32_t __attribute__((ext_vector_type(64))); + +// i16 +// using int16_t = ... +using int16x2_t = int16_t __attribute__((ext_vector_type(2))); +using int16x4_t = int16_t __attribute__((ext_vector_type(4))); +using int16x8_t = int16_t __attribute__((ext_vector_type(8))); +using int16x16_t = int16_t __attribute__((ext_vector_type(16))); +using int16x32_t = int16_t __attribute__((ext_vector_type(32))); +using int16x64_t = int16_t __attribute__((ext_vector_type(64))); + +// u16 +// using uint16_t +using uint16x2_t = uint16_t __attribute__((ext_vector_type(2))); +using uint16x4_t = uint16_t __attribute__((ext_vector_type(4))); +using uint16x8_t = uint16_t __attribute__((ext_vector_type(8))); +using uint16x16_t = uint16_t __attribute__((ext_vector_type(16))); +using uint16x32_t = uint16_t __attribute__((ext_vector_type(32))); +using uint16x64_t = uint16_t __attribute__((ext_vector_type(64))); + +// i8 +// using int8_t +using int8x2_t = int8_t __attribute((ext_vector_type(2))); +using int8x4_t = int8_t __attribute((ext_vector_type(4))); +using int8x8_t = int8_t __attribute((ext_vector_type(8))); +using int8x16_t = int8_t __attribute((ext_vector_type(16))); +using int8x32_t = int8_t __attribute((ext_vector_type(32))); +using int8x64_t = int8_t __attribute((ext_vector_type(64))); + +#if CK_TILE_USE_CUSTOM_DATA_TYPE +// f8 +// using fp8_t +using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2))); +using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4))); +using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8))); +using fp8x16_t = fp8_raw_t __attribute((ext_vector_type(16))); +using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32))); +using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64))); + +// bf8 +// using bf8_t +using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2))); +using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4))); +using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8))); +using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16))); +using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32))); +using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64))); +#else +// f8 +// using fp8_t +using fp8x2_t = fp8_t __attribute((ext_vector_type(2))); +using fp8x4_t = fp8_t __attribute((ext_vector_type(4))); +using fp8x8_t = fp8_t __attribute((ext_vector_type(8))); +using fp8x16_t = fp8_t __attribute((ext_vector_type(16))); +using fp8x32_t = fp8_t __attribute((ext_vector_type(32))); +using fp8x64_t = fp8_t __attribute((ext_vector_type(64))); + +// bf8 +// using bf8_t +using bf8x2_t = bf8_t __attribute((ext_vector_type(2))); +using bf8x4_t = bf8_t __attribute((ext_vector_type(4))); +using bf8x8_t = bf8_t __attribute((ext_vector_type(8))); +using bf8x16_t = bf8_t 
__attribute((ext_vector_type(16))); +using bf8x32_t = bf8_t __attribute((ext_vector_type(32))); +using bf8x64_t = bf8_t __attribute((ext_vector_type(64))); +#endif + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp new file mode 100644 index 000000000..96b38241c --- /dev/null +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -0,0 +1,1068 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/amd_buffer_addressing.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/float8.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// T may be scalar or vector +// X may be scalar or vector +// T and X have same scalar type +// X contains multiple T +// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of +// transforms of tensor_view/Tensor +// FIXME: amd_buffer_coherence_enum is only meaningful for buffer addressing. Need to split +// buffer_view definition for different memory address space (Global/GenericLds/Vgpr) +template +struct buffer_view; + +// Address Space: generic +// T may be scalar or vector +// X may be scalar or vector +// T and X have same scalar type +// X contains multiple T +// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of +// transforms of tensor_view/Tensor +template +struct buffer_view +{ + using type = T; + + T* p_data_ = nullptr; + BufferSizeType buffer_size_; + remove_cvref_t invalid_element_value_ = T{0}; + + CK_TILE_HOST_DEVICE constexpr buffer_view() + : p_data_{}, buffer_size_{}, invalid_element_value_{} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, + BufferSizeType buffer_size, + T invalid_element_value) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} + { + } + + CK_TILE_DEVICE static constexpr address_space_enum get_address_space() + { + return address_space_enum::generic; + } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; } + + // i is offset of T, not X. i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto + get(index_t i, bool is_valid_element, bool_constant = {}) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! 
X should contain multiple T"); + + if(is_valid_element) + { +#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp; + + __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X)); + + return tmp; +#else + return *c_style_pointer_cast(&p_data_[i]); +#endif + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return X{numeric>::zero()}; + } + else + { + return X{invalid_element_value_}; + } + } + } + + // i is offset of T, not X. i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) + { + if constexpr(Op == memory_operation_enum::set) + { + this->template set(i, is_valid_element, x); + } + // FIXME: remove memory_operation_enum::add + else if constexpr(Op == memory_operation_enum::add) + { + auto tmp = this->template get(i, is_valid_element); + this->template set(i, is_valid_element, x + tmp); + } + } + + // i is offset of T, not X. i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + if(is_valid_element) + { +#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp = x; + + __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); +#else + *c_style_pointer_cast(&p_data_[i]) = x; +#endif + } + } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; } + + CK_TILE_HOST_DEVICE void print() const + { + printf("buffer_view{"); + + // AddressSpace + printf("AddressSpace: generic, "); + + // p_data_ + printf("p_data_: %p, ", static_cast(const_cast*>(p_data_))); + + // buffer_size_ + printf("buffer_size_: "); + print(buffer_size_); + printf(", "); + + // invalid_element_value_ + printf("invalid_element_value_: "); + print(invalid_element_value_); + + printf("}"); + } +}; + +// Address Space: Global +// T may be scalar or vector +// X may be scalar or vector +// T and X have same scalar type +// X contains multiple T +// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of +// transforms of tensor_view/Tensor +template +struct buffer_view +{ + using type = T; + + T* p_data_ = nullptr; + BufferSizeType buffer_size_; + remove_cvref_t invalid_element_value_ = T{0}; + + CK_TILE_HOST_DEVICE constexpr buffer_view() + : p_data_{}, buffer_size_{}, invalid_element_value_{} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, + BufferSizeType buffer_size, + T invalid_element_value) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} + { + } + + CK_TILE_DEVICE static constexpr address_space_enum get_address_space() + { + return address_space_enum::global; + } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; } + + // i is offset of T + // FIXME: 
doesn't do is_valid check + CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; } + + // i is offset of T, not X. i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto + get(index_t i, bool is_valid_element, bool_constant = {}) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + +#if CK_TILE_USE_AMD_BUFFER_LOAD + bool constexpr use_amd_buffer_addressing = true; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + if constexpr(InvalidElementUseNumericalZeroValue) + { + return amd_buffer_load_invalid_element_return_zero, + t_per_x, + Coherence, + oob_conditional_check>( + p_data_, i, is_valid_element, buffer_size_); + } + else + { + return amd_buffer_load_invalid_element_return_customized_value< + remove_cvref_t, + t_per_x, + Coherence, + oob_conditional_check>( + p_data_, i, is_valid_element, buffer_size_, invalid_element_value_); + } + } + else + { + if(is_valid_element) + { +#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp; + + __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X)); + + return tmp; +#else + return *c_style_pointer_cast(&p_data_[i]); +#endif + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return X{numeric>::zero()}; + } + else + { + return X{invalid_element_value_}; + } + } + } + } + + // i is offset of T, not X. i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto + get_raw(remove_cvref_t& dst, index_t i, bool is_valid_element) const + { + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_load_raw, t_per_x, Coherence, oob_conditional_check>( + dst, p_data_, i, buffer_size_, is_valid_element); + } + + // i is offset of T, not X. i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto + async_get(remove_cvref_t* smem, index_t i, bool /*is_valid_element*/) const + { + // X is vector of T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_async_buffer_load_with_oob, t_per_x, Coherence>( + smem, p_data_, i, buffer_size_); + } + + // i is offset of T, not X. 
i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) + { + if constexpr(Op == memory_operation_enum::set) + { + this->template set(i, is_valid_element, x); + } + else if constexpr(Op == memory_operation_enum::atomic_add) + { + this->template atomic_add(i, is_valid_element, x); + } + else if constexpr(Op == memory_operation_enum::atomic_max) + { + this->template atomic_max(i, is_valid_element, x); + } + // FIXME: remove memory_operation_enum::add + else if constexpr(Op == memory_operation_enum::add) + { + auto tmp = this->template get(i, is_valid_element); + this->template set(i, is_valid_element, x + tmp); + // tmp += x; + // this->template set(i, is_valid_element, tmp); + } + } + + // i is offset of T, not X. i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + +#if CK_TILE_USE_AMD_BUFFER_STORE + bool constexpr use_amd_buffer_addressing = true; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_store, t_per_x, Coherence>( + x, p_data_, i, is_valid_element, buffer_size_); + } + else + { + if(is_valid_element) + { +#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp = x; + + __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); +#else + *c_style_pointer_cast(&p_data_[i]) = x; +#endif + } + } + } + + // i is offset of T, not X. i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + amd_buffer_store_raw, t_per_x, Coherence, oob_conditional_check>( + x, p_data_, i, is_valid_element, buffer_size_); + } + + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x) + { + using scalar_t = typename vector_traits>::scalar_type; + + // X contains multiple T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! 
X should contain multiple T"); + + static_assert(get_address_space() == address_space_enum::global, "only support global mem"); + +#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT + bool constexpr use_amd_buffer_addressing = + std::is_same_v, int32_t> || + std::is_same_v, float> || + (std::is_same_v, half_t> && scalar_per_x_vector % 2 == 0); +#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) + bool constexpr use_amd_buffer_addressing = + std::is_same_v, int32_t>; +#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT + bool constexpr use_amd_buffer_addressing = + std::is_same_v, float> || + (std::is_same_v, half_t> && scalar_per_x_vector % 2 == 0); +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_add, t_per_x>( + x, p_data_, i, is_valid_element, buffer_size_); + } + else + { + if(is_valid_element) + { + atomic_add(c_style_pointer_cast(&p_data_[i]), x); + } + } + } + + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + static_assert(get_address_space() == address_space_enum::global, "only support global mem"); + +#if CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 + using scalar_t = typename vector_traits>::scalar_type; + bool constexpr use_amd_buffer_addressing = std::is_same_v, double>; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_max, t_per_x>( + x, p_data_, i, is_valid_element, buffer_size_); + } + else if(is_valid_element) + { + atomic_max(c_style_pointer_cast(&p_data_[i]), x); + } + } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; } + + CK_TILE_HOST_DEVICE void print() const + { + printf("buffer_view{"); + + // AddressSpace + printf("AddressSpace: Global, "); + + // p_data_ + printf("p_data_: %p, ", static_cast(const_cast*>(p_data_))); + + // buffer_size_ + printf("buffer_size_: "); + print(buffer_size_); + printf(", "); + + // invalid_element_value_ + printf("invalid_element_value_: "); + print(invalid_element_value_); + + printf("}"); + } +}; + +// Address Space: LDS +// T may be scalar or vector +// X may be scalar or vector +// T and X have same scalar type +// X contains multiple T +// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of +// transforms of tensor_view/Tensor +template +struct buffer_view +{ + using type = T; + + T* p_data_ = nullptr; + BufferSizeType buffer_size_; + remove_cvref_t invalid_element_value_ = T{0}; + + CK_TILE_HOST_DEVICE constexpr buffer_view() + : p_data_{}, buffer_size_{}, invalid_element_value_{} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) + : p_data_{p_data}, 
buffer_size_{buffer_size}, invalid_element_value_{0} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, + BufferSizeType buffer_size, + T invalid_element_value) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} + { + } + + CK_TILE_DEVICE static constexpr address_space_enum get_address_space() + { + return address_space_enum::lds; + } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; } + + // i is offset of T, not X. i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto + get(index_t i, bool is_valid_element, bool_constant = {}) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + if(is_valid_element) + { +#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp; + + __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X)); + + return tmp; +#else + using buf_t = ext_vector_t>::scalar_type, + scalar_per_t_vector * scalar_per_x_vector>; + // using buf_t = ushort __attribute__((ext_vector_type(8))); + auto rtn = *c_style_pointer_cast(&p_data_[i]); + return bit_cast(rtn); +#endif + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return X{numeric>::zero()}; + } + else + { + return X{invalid_element_value_}; + } + } + } + + // i is offset of T, not X. i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) + { + if constexpr(Op == memory_operation_enum::set) + { + this->template set(i, is_valid_element, x); + } + // FIXME: remove memory_operation_enum::add + else if constexpr(Op == memory_operation_enum::add) + { + auto tmp = this->template get(i, is_valid_element); + this->template set(i, is_valid_element, x + tmp); + } + } + + // i is offset of T, not X. i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! 
X should contain multiple T"); + +#if CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE + bool constexpr workaround_int8_ds_write_issue = true; +#else + bool constexpr workaround_int8_ds_write_issue = false; +#endif + + if constexpr(std::is_same>::scalar_type, + int8_t>::value && + workaround_int8_ds_write_issue) + { + if(is_valid_element) + { + // HACK: compiler would lower IR "store address_space(3)" into inefficient + // ISA, so I try to let compiler emit IR "store" which would be lower to + // ds_write_b128 + // TODO: remove this after compiler fix + static_assert((std::is_same, int8_t>::value && + std::is_same, int8_t>::value) || + (std::is_same, int8_t>::value && + std::is_same, int8x2_t>::value) || + (std::is_same, int8_t>::value && + std::is_same, int8x4_t>::value) || + (std::is_same, int8_t>::value && + std::is_same, int8x8_t>::value) || + (std::is_same, int8_t>::value && + std::is_same, int8x16_t>::value) || + (std::is_same, int8x4_t>::value && + std::is_same, int8x4_t>::value) || + (std::is_same, int8x8_t>::value && + std::is_same, int8x8_t>::value) || + (std::is_same, int8x16_t>::value && + std::is_same, int8x16_t>::value), + "wrong! not implemented for this combination, please add " + "implementation"); + + if constexpr(std::is_same, int8_t>::value && + std::is_same, int8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(std::is_same, int8_t>::value && + std::is_same, int8x2_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(std::is_same, int8_t>::value && + std::is_same, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(std::is_same, int8_t>::value && + std::is_same, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(std::is_same, int8_t>::value && + std::is_same, int8x16_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(std::is_same, int8x4_t>::value && + std::is_same, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(std::is_same, int8x8_t>::value && + std::is_same, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(std::is_same, int8x16_t>::value && + std::is_same, int8x16_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + } + } + else + { + if(is_valid_element) + { +#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp = x; + + __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); +#else + using buf_t = ext_vector_t>::scalar_type, + scalar_per_t_vector * scalar_per_x_vector>; + + *c_style_pointer_cast(&p_data_[i]) = reinterpret_cast(x); +#endif + } + } + } + + // FIXME: remove + CK_TILE_DEVICE static 
constexpr bool is_static_buffer() { return false; } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; } + + CK_TILE_HOST_DEVICE void print() const + { + printf("buffer_view{"); + + // AddressSpace + printf("AddressSpace: Lds, "); + + // p_data_ + printf("p_data_: %p, ", static_cast(const_cast*>(p_data_))); + + // buffer_size_ + printf("buffer_size_: "); + print(buffer_size_); + printf(", "); + + // invalid_element_value_ + printf("invalid_element_value_: "); + print(invalid_element_value_); + + printf("}"); + } +}; + +// Address Space: Vgpr +// T may be scalar or vector +// X may be scalar or vector +// T and X have same scalar type +// X contains multiple T +// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of +// transforms of tensor_view/Tensor +template +struct buffer_view +{ + using type = T; + + T* p_data_ = nullptr; + BufferSizeType buffer_size_; + remove_cvref_t invalid_element_value_ = T{0}; + + CK_TILE_HOST_DEVICE constexpr buffer_view() + : p_data_{}, buffer_size_{}, invalid_element_value_{} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0} + { + } + + CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, + BufferSizeType buffer_size, + T invalid_element_value) + : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value} + { + } + + CK_TILE_DEVICE static constexpr address_space_enum get_address_space() + { + return address_space_enum::vgpr; + } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; } + + // i is offset of T + // FIXME: doesn't do is_valid check + CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; } + + // i is offset of T, not X. i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE constexpr auto + get(index_t i, bool is_valid_element, bool_constant = {}) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + if(is_valid_element) + { +#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp; + + __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X)); + + return tmp; +#else + return *c_style_pointer_cast(&p_data_[i]); +#endif + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return X{numeric>::zero()}; + } + else + { + return X{invalid_element_value_}; + } + } + } + + // i is offset of T, not X. i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x) + { + if constexpr(Op == memory_operation_enum::set) + { + this->template set(i, is_valid_element, x); + } + // FIXME: remove memory_operation_enum::add + else if constexpr(Op == memory_operation_enum::add) + { + auto tmp = this->template get(i, is_valid_element); + this->template set(i, is_valid_element, x + tmp); + } + } + + // i is offset of T, not X. 
i should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>::value, + bool>::type = false> + CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = vector_traits>::vector_size; + + constexpr index_t scalar_per_x_vector = vector_traits>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + if(is_valid_element) + { +#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp = x; + + __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); +#else + *c_style_pointer_cast(&p_data_[i]) = x; +#endif + } + } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; } + + // FIXME: remove + CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; } + + CK_TILE_HOST_DEVICE void print() const + { + printf("buffer_view{"); + + // AddressSpace + printf("AddressSpace: Vgpr, "); + + // p_data_ + printf("p_data_: %p, ", static_cast(const_cast*>(p_data_))); + + // buffer_size_ + printf("buffer_size_: "); + print(buffer_size_); + printf(", "); + + // invalid_element_value_ + printf("invalid_element_value_: "); + print(invalid_element_value_); + + printf("}"); + } +}; + +template +CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* p, BufferSizeType buffer_size) +{ + return buffer_view{p, buffer_size}; +} + +template , remove_cvref_t>::value, + bool>::type = false> +CK_TILE_HOST_DEVICE constexpr auto +make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value) +{ + return buffer_view{ + p, buffer_size, invalid_element_value}; +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp new file mode 100644 index 000000000..288a60602 --- /dev/null +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
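// ---------------------------------------------------------------------------
// [Editor's note] Device-side usage sketch, not part of the patch. It assumes
// the address space is the leading template argument of make_buffer_view, as
// suggested by the buffer_view specializations above. get()/set() move whole
// X-typed vectors (here fp32x4_t) at a T-granular offset; out-of-bounds
// accesses return zeros / are dropped because the invalid-element value is
// left at its default.
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"

CK_TILE_DEVICE void buffer_view_usage_sketch(float* p_global,
                                             ck_tile::index_t n,
                                             ck_tile::index_t i) // float offset, aligned to 4
{
    using namespace ck_tile;
    auto buf = make_buffer_view<address_space_enum::global>(p_global, n);

    const bool is_valid = (i + 4) <= n;
    fp32x4_t v = buf.get<fp32x4_t>(i, is_valid); // zero-filled when !is_valid
    v = v * 2.f;
    buf.set<fp32x4_t>(i, is_valid, v);           // skipped when !is_valid
}
// ---------------------------------------------------------------------------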
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/tensor/null_tile_window.hpp" +#include "ck_tile/core/tensor/null_tensor.hpp" + +namespace ck_tile { + +template +CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution& tile_window, + bool_constant = {}) +{ + return tile_window.load(bool_constant{}); +} + +template +CK_TILE_DEVICE auto load_tile_raw(T& tile, + const tile_window_with_static_distribution& tile_window, + bool_constant = {}) +{ + tile_window.load_raw(tile, bool_constant{}); +} + +template +CK_TILE_DEVICE auto +async_load_tile_raw(LdsTileWindow_&& lds_tile, + const tile_window_with_static_distribution& tile_window) +{ + return tile_window.async_load(lds_tile); +} + +CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0) +{ + asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory"); +} + +template +CK_TILE_DEVICE auto load_tile(const null_tile_window&) +{ + return null_tensor{}; +} + +template +CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window&) +{ +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/null_tensor.hpp b/include/ck_tile/core/tensor/null_tensor.hpp new file mode 100644 index 000000000..565ff87df --- /dev/null +++ b/include/ck_tile/core/tensor/null_tensor.hpp @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +struct null_tensor +{ +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/null_tile_window.hpp b/include/ck_tile/core/tensor/null_tile_window.hpp new file mode 100644 index 000000000..89806203a --- /dev/null +++ b/include/ck_tile/core/tensor/null_tile_window.hpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
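// ---------------------------------------------------------------------------
// [Editor's note] Usage sketch, not part of the patch. load_tile() above
// returns a register-resident distributed tensor for a tile window, while the
// null_tile_window / null_tensor overloads let the same generic code compile
// when an optional operand is opted out. async_load_fence(N) waits until at
// most N asynchronous buffer loads are still outstanding (s_waitcnt vmcnt).
#include "ck_tile/core/tensor/load_tile.hpp"

template <typename Window> // a tile_window_with_static_distribution or a null_tile_window
CK_TILE_DEVICE auto load_or_skip(const Window& w)
{
    return ck_tile::load_tile(w); // yields null_tensor{} when Window is a null_tile_window
}
// ---------------------------------------------------------------------------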
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/tensor/tensor_view.hpp" + +namespace ck_tile { + +// placeholder type if we want to opt-out a tile window parameter +template +struct null_tile_window +{ + using BottomTensorView = null_tensor_view; + using WindowLengths = remove_cvref_t; + + using BottomTensorIndex = array; + + CK_TILE_DEVICE constexpr null_tile_window() = default; + + CK_TILE_DEVICE constexpr null_tile_window(const WindowLengths& window_lengths) + : window_lengths_{window_lengths} + { + } + + CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } + + CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return null_tensor_view{}; } + + CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; } + + WindowLengths window_lengths_; +}; + +// utility to check if this is a Null Tile Window +namespace impl { +template +struct is_null_tile_window : public std::false_type +{ +}; + +template +struct is_null_tile_window> : public std::true_type +{ +}; +} // namespace impl + +template +CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&) +{ + return impl::is_null_tile_window>::value; +} + +template +CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths& window_lengths) +{ + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + + return null_tile_window>{window_lengths}; +} + +template +CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, + const WindowLengths& window_lengths, + const multi_index& /*origin*/, + Ts&&...) +{ + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + + return null_tile_window>{window_lengths}; +} + +template +CK_TILE_DEVICE void +move_tile_window(null_tile_window&, + const typename null_tile_window::BottomTensorIndex&) +{ +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/shuffle_tile.hpp b/include/ck_tile/core/tensor/shuffle_tile.hpp new file mode 100644 index 000000000..baf009add --- /dev/null +++ b/include/ck_tile/core/tensor/shuffle_tile.hpp @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
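// ---------------------------------------------------------------------------
// [Editor's note] Usage sketch, not part of the patch. A null_tile_window is
// a compile-time placeholder for an operand a kernel does not use: loading
// from it yields null_tensor{} and moving it is a no-op, so generic pipeline
// code needs no special casing. The lengths are assumed to be a
// statically-known type (e.g. a tuple of number<>), as required by the
// static_assert in make_null_tile_window.
#include "ck_tile/core/tensor/load_tile.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"

template <typename WindowLengths>
CK_TILE_DEVICE void null_window_usage_sketch(const WindowLengths& lengths)
{
    auto w = ck_tile::make_null_tile_window(lengths); // placeholder window, no tensor behind it
    ck_tile::move_tile_window(w, {});                 // compiles, but does nothing
    [[maybe_unused]] auto t = ck_tile::load_tile(w);  // yields ck_tile::null_tensor{}
}
// ---------------------------------------------------------------------------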
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/algorithm/space_filling_curve.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/thread_buffer.hpp" +#include "ck_tile/core/container/statically_indexed_array.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/tensor/tile_elementwise.hpp" +#include "ck_tile/core/utility/transpose_vectors.hpp" + +namespace ck_tile { +namespace detail { + +template +CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor) +{ + constexpr auto I0 = number<0>{}; + + using DataType = typename InTensor::DataType; + + constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor(); + constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor(); + + // y_dim_out_to_in + constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) { + using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode; + + map, index_t> rh_major_minor_to_y_; + + static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) { + constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i]; + constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i]; + + rh_major_minor_to_y_({rh_major, rh_minor}) = i; + }); + + return rh_major_minor_to_y_; + }; + + constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{}); + constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{}); + + constexpr auto y_dim_out_to_in = [&] { + map y_dim_out_to_in_; + + for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out) + { + y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor]; + } + + return y_dim_out_to_in_; + }(); + + // + constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y(); + + constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths()); + + // input and output vector dim in the order of input Y dims + constexpr index_t y_dim_vec_in = NDimY - 1; + constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1]; + + // vector lengths + constexpr index_t vec_length_in = y_lengths[y_dim_vec_in]; + constexpr index_t vec_length_out = y_lengths[y_dim_vec_out]; + + // # of vectors + constexpr index_t num_vec_in = vec_length_out; + constexpr index_t num_vec_out = vec_length_in; + + using InVec = array; + using OutVec = array; + + // using InVec = typename InVec::type; + // using OutVec = typename OutVec::type; + + // SFC + constexpr auto scalars_per_access_arr = generate_array( + [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; }, + number{}); + + constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY); + + using SFC_Y = space_filling_curve::type, + decltype(scalars_per_access)>; + + constexpr index_t num_access = SFC_Y::get_num_of_access(); + + static_assert(num_access > 0, "wrong! num_access should be larger than 0"); + + // in/out vectors to be transposed + thread_buffer in_vectors; + thread_buffer out_vectors; + + // loop over SFC and do transpose + static_for<0, num_access, 1>{}([&](auto iAccess) { + // data index [y0, y1, ...] 
in the order of input tensor + constexpr auto idx_y_start = SFC_Y::get_index(iAccess); + + // get input vectors + static_for<0, num_vec_in, 1>{}([&](auto i) { + constexpr auto idx_y_in = generate_array( + [&](auto ii) { + return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii]; + }, + number{}); + + constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in); + static_assert(in_offset % vec_length_in == 0); + + in_vectors(i).template get_as()(I0) = + in_tensor.get_thread_buffer() + .template get_as()[number{}]; + }); + + // transpose + transpose_vectors{}(in_vectors, out_vectors); + + // set output vectors + static_for<0, num_vec_out, 1>{}([&](auto i) { + constexpr auto idx_y_out_tmp = generate_array( + [&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; }, + number{}); + + constexpr auto idx_y_out = + container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in); + + constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out); + static_assert(out_offset % vec_length_out == 0); + + out_tensor.get_thread_buffer().template set_as( + number{}, + out_vectors[i].template get_as()[I0]); + }); + }); +} + +} // namespace detail + +template +CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in) +{ + using InDataType = typename InTensor::DataType; + using OutDataType = typename OutTensor::DataType; + + using InDstrEncode = typename InTensor::StaticTileDistribution::DstrEncode; + using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode; + + // type convert + const auto in_tmp = tile_elementwise_in(type_convert, in); + + // shuffle + if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ && + InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ && + InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ && + InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ && + InDstrEncode::NDimY == OutDstrEncode::NDimY) + { + detail::shuffle_tile_impl_in_thread(out, in_tmp); + } + else + { + // NOT implemented + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/slice_tile.hpp b/include/ck_tile/core/tensor/slice_tile.hpp new file mode 100644 index 000000000..7a4ba2eb7 --- /dev/null +++ b/include/ck_tile/core/tensor/slice_tile.hpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +CK_TILE_DEVICE constexpr auto +get_slice_tile(const tile_window_with_static_lengths& tile, + sequence slice_begins, + sequence slice_ends) +{ + using TileWindow = tile_window_with_static_lengths; + // NOTE: This API will override the origin of the tile window! 
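+    // Illustrative example (shapes are hypothetical): for a 128x64 window,
+    // slice_begins = sequence<0, 32>{} and slice_ends = sequence<128, 64>{}
+    // produce a 128x32 window whose origin is reset to (0, 32) in the
+    // underlying tensor view; the previous window origin is discarded.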
+ static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds)); + static_assert(sizeof...(SliceBegins) == TileWindow::get_num_of_dimension()); + + constexpr auto slice_lengths = slice_ends - slice_begins; + + return make_tile_window(tile.get_bottom_tensor_view(), + sequence_to_tuple_of_number(slice_lengths), + to_multi_index(slice_begins)); +} + +template +CK_TILE_DEVICE constexpr auto +get_slice_tile(const static_distributed_tensor& tile, + sequence slice_begins, + sequence slice_ends) +{ + using DataType = remove_cvref_t; + using Distribution = remove_cvref_t; + + constexpr auto sliced_dstr_yidx_ylen = + detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends); + + constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>(); + constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>(); + constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>(); + + auto sliced_tensor = make_static_distributed_tensor(sliced_dstr); + + sliced_tensor.get_thread_buffer() = + tile.get_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths); + + return sliced_tensor; +} + +template +CK_TILE_DEVICE constexpr auto +set_slice_tile(static_distributed_tensor& dst_tile, + const static_distributed_tensor& src_tile, + sequence slice_begins, + sequence slice_ends) +{ + using DstDistribution = remove_cvref_t; + + constexpr auto sliced_dstr_yidx_ylen = + detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends); + + constexpr auto sliced_dstr = sliced_dstr_yidx_ylen.template at<0>(); + constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>(); + constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>(); + + static_assert(std::is_same_v, "wrong!"); + + dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer()); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp new file mode 100644 index 000000000..299a74bc0 --- /dev/null +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/container/thread_buffer.hpp" + +namespace ck_tile { + +template +struct static_distributed_tensor +{ + using DataType = remove_cvref_t; + using StaticTileDistribution = remove_cvref_t; + + static_assert(StaticTileDistribution::is_static(), + "wrong! 
StaticTileDistribution should be known at compile tile"); + + using ThreadTensorDesc = + remove_cvref_t; + + static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size(); + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension() + { + return StaticTileDistribution::get_num_of_dimension_x(); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_lengths() + { + return StaticTileDistribution::get_lengths(); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_tile_distribution() + { + return StaticTileDistribution{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans() + { + return StaticTileDistribution::get_distributed_spans(); + } + + CK_TILE_HOST_DEVICE void initialize(const DataType& x) { thread_buf_.initialize(x); } + + CK_TILE_HOST_DEVICE constexpr const auto& get_thread_buffer() const { return thread_buf_; } + + CK_TILE_HOST_DEVICE constexpr auto& get_thread_buffer() { return thread_buf_; } + + CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size() + { + return kThreadElementSpaceSize; + } + + template + CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence, + sequence) const + { + static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY && + sizeof...(YSliceLengths) == StaticTileDistribution::NDimY, + "wrong!"); + + constexpr auto sliced_thread_tensor_desc = + make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...)); + + thread_buffer + sliced_thread_data; + + static_ford>{}([&](auto idx) { + constexpr auto idx_ys = idx + sequence{}; + + sliced_thread_data(number{}) = + thread_buf_[number{}]; + }); + + return sliced_thread_data; + } + + template + CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence, + sequence, + const SlicedThreadData& sliced_thread_data) + { + static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY && + sizeof...(YSliceLengths) == StaticTileDistribution::NDimY, + "wrong!"); + + constexpr auto sliced_thread_tensor_desc = + make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...)); + + static_ford>{}([&](auto idx) { + constexpr auto idx_ys = idx + sequence{}; + + thread_buf_(number{}) = + sliced_thread_data[number{}]; + }); + } + + template + CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const + { + static_assert(is_static_v, + "wrong! Tile Distributed Indices should be static"); + + constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices( + TileDistributedIndices{}); + + return thread_buf_[number{}]; + } + + template + CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices) + { + static_assert(is_static_v, + "wrong! 
Tile Distributed Indices should be static"); + + constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices( + TileDistributedIndices{}); + + return thread_buf_(number{}); + } + + // + thread_buffer thread_buf_; +}; + +template +CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&) +{ + return static_distributed_tensor, + remove_cvref_t>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&, + ThreadBuffer&& thread_buffer_) +{ + return static_distributed_tensor, + remove_cvref_t>{thread_buffer_}; +} + +// get X indices from tuple of tile_distributed_index<> +template +CK_TILE_HOST_DEVICE constexpr auto +get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, + DistributedIndices distributed_indices) +{ + const auto partition_index = detail::get_partition_index(tile_distribution); + constexpr auto y_indices = + tile_distribution.get_y_indices_from_distributed_indices(distributed_indices); + + const auto x_coord = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), + container_concat(partition_index, to_array(y_indices))); + + return x_coord.get_bottom_index(); +} + +template +CK_TILE_HOST_DEVICE void +set_tile_if(static_distributed_tensor& out_tensor, + DataType value, + XIndicesPredicate predicate) +{ + constexpr auto out_spans = + static_distributed_tensor::get_distributed_spans(); + sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) { + constexpr auto distributed_indices = make_tuple(idx0, idx1); + const auto x_indices = get_x_indices_from_distributed_indices(StaticTileDistribution{}, + distributed_indices); + + if(predicate(x_indices)) + { + out_tensor(distributed_indices) = value; + } + }); + }); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/store_tile.hpp b/include/ck_tile/core/tensor/store_tile.hpp new file mode 100644 index 000000000..c12ad883d --- /dev/null +++ b/include/ck_tile/core/tensor/store_tile.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
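+
+// store_tile(window, tensor) writes a register-resident static_distributed_tensor
+// out through a tile window. The overloads below accept either a
+// tile_window_with_static_lengths (a distributed tile window is created
+// internally from the tensor's tile distribution) or a
+// tile_window_with_static_distribution. A hedged sketch (`out_window` and
+// `out_tile` are placeholder names):
+//
+//   ck_tile::store_tile(out_window, out_tile);     // regular store
+//   ck_tile::store_tile_raw(out_window, out_tile); // forwards to the window's store_raw()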
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +CK_TILE_DEVICE void +store_tile(tile_window_with_static_lengths& tile_window_tmp, + const static_distributed_tensor& dstr_tensor) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(std::is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + tile_dstr); + + tile_window.store(dstr_tensor); +} + +template +CK_TILE_DEVICE void +store_tile_raw(tile_window_with_static_lengths& tile_window_tmp, + const static_distributed_tensor& dstr_tensor) +{ + using DataType = remove_cvref_t; + using TileDstr = remove_cvref_t; + + static_assert(std::is_same_v, DataType>, "wrong!"); + + constexpr auto tile_dstr = TileDstr{}; + + auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(), + tile_window_tmp.get_window_lengths(), + tile_window_tmp.get_window_origin(), + tile_dstr); + + tile_window.store_raw(dstr_tensor); +} + +template +CK_TILE_DEVICE void +store_tile(tile_window_with_static_distribution& tile_window, + const static_distributed_tensor& dstr_tensor) +{ + tile_window.store(dstr_tensor); +} + +template +CK_TILE_DEVICE void +store_tile_raw(tile_window_with_static_distribution& tile_window, + const static_distributed_tensor& dstr_tensor) +{ + tile_window.store_raw(dstr_tensor); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/sweep_tile.hpp b/include/ck_tile/core/tensor/sweep_tile.hpp new file mode 100644 index 000000000..f1511f11d --- /dev/null +++ b/include/ck_tile/core/tensor/sweep_tile.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// sweep over a span of a distribted tile and apply lambda function F +template + typename F // signature: F(tile_distributed_index<...>) + > +CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f) +{ + using DstrSpan = remove_cvref_t; + + static_ford{}([&](auto dstr_idx_impl) { + constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl); + + f(dstr_idx); + }); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp new file mode 100644 index 000000000..6bcba4019 --- /dev/null +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -0,0 +1,945 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
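+
+// A tensor_adaptor chains coordinate transforms that map a "top" multi-index to
+// a "bottom" multi-index through hidden dimensions. A hedged sketch of a
+// single-stage adaptor that linearizes a (4 x 8) top index into one bottom
+// dimension; the lengths are illustrative only and the helper calls are assumed
+// to behave as elsewhere in ck_tile:
+//
+//   using namespace ck_tile;
+//   constexpr auto adaptor = make_single_stage_tensor_adaptor(
+//       make_tuple(make_unmerge_transform(make_tuple(number<4>{}, number<8>{}))),
+//       make_tuple(sequence<0>{}),      // lower (bottom / old top) dimension ids
+//       make_tuple(sequence<0, 1>{}));  // upper (new top) dimension ids
+//
+//   // adaptor.calculate_bottom_index(make_multi_index(m, n)) yields {m * 8 + n}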
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "ck_tile/core/numeric/numeric.hpp" + +namespace ck_tile { + +// Transforms: Tuple +// LowerDimensionHiddenIdss : Tuple, ...> +// UpperDimensionHiddenIdss : Tuple, ...> +// BottomDimensionHiddenIds : Sequence<...> +// TopDimensionHiddenIds : Sequence<...> +template +struct tensor_adaptor +{ + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_transform() + { + return Transforms::size(); + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const { return transforms_; } + + CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss() + { + return LowerDimensionHiddenIdss{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss() + { + return UpperDimensionHiddenIdss{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_bottom_dimension_hidden_ids() + { + return BottomDimensionHiddenIds{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids() + { + return TopDimensionHiddenIds{}; + } + + CK_TILE_HOST_DEVICE static constexpr auto initialize_element_size(const Transforms& transforms) + { + const auto lengths = generate_tuple( + [&](auto idim_top) { + constexpr index_t idim_hidden = TopDimensionHiddenIds::at(idim_top); + + constexpr auto tmp = get_transform_and_its_upper_dimension(number{}); + + constexpr index_t itran = tmp[number<0>{}]; + constexpr index_t idim_up = tmp[number<1>{}]; + constexpr bool found = tmp[number<2>{}]; + + static_assert(found == true, + "wrong! not found matching transformation and upper-dimension"); + + const auto length = + transforms[number{}].get_upper_lengths()[number{}]; + + return length; + }, + number{}); + + // TODO: make container_reduce support tuple of number and index_t + return container_reduce(lengths, multiplies{}, number<1>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto + get_transform_and_its_upper_dimension(number) + { + // FIXME: length of bottom dimension is not known, since info about lower dim length are not + // saved in transformation + static_assert(IDimHidden >= ndim_bottom_, "wrong! not implemented"); + + index_t itran_found = 0; + index_t idim_up_found = 0; + bool found = false; + + static_for<0, ntransform_, 1>{}([&](auto itran) { + constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran]; + + static_for<0, up_dim_ids.size(), 1>{}([&](auto idim_up) { + if constexpr(up_dim_ids[idim_up] == IDimHidden) + { + itran_found = itran; + idim_up_found = idim_up; + found = true; + } + }); + }); + + return make_tuple(itran_found, idim_up_found, found); + } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_bottom_dimension() + { + return BottomDimensionHiddenIds::size(); + } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_top_dimension() + { + return TopDimensionHiddenIds::size(); + } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension() + { + constexpr auto all_low_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, + LowerDimensionHiddenIdss{}); + + constexpr auto all_up_dim_ids = unpack( + [](auto&&... 
xs) constexpr { return merge_sequences(xs...); }, + UpperDimensionHiddenIdss{}); + + constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); + + using unique_sort_all_dim_ids = typename sequence_unique_sort, + equal>::type; + + return unique_sort_all_dim_ids::size(); + } + + constexpr static index_t ntransform_ = get_num_of_transform(); + constexpr static index_t ndim_hidden_ = get_num_of_hidden_dimension(); + constexpr static index_t ndim_bottom_ = get_num_of_bottom_dimension(); + constexpr static index_t ndim_top_ = get_num_of_top_dimension(); + + using HiddenIndex = multi_index; + using BottomIndex = multi_index; + using TopIndex = multi_index; + + // may be index_t or number<> + using ElementSize = remove_cv_t; + + public: + CK_TILE_HOST_DEVICE constexpr tensor_adaptor() = default; + + CK_TILE_HOST_DEVICE constexpr tensor_adaptor(const Transforms& transforms) + : transforms_{transforms}, element_size_{initialize_element_size(transforms)} + { + static_assert(Transforms::size() == ntransform_ && + LowerDimensionHiddenIdss::size() == ntransform_ && + UpperDimensionHiddenIdss::size() == ntransform_, + "wrong! inconsistent # of transformations"); + + // TODO check dependency of dimensions is valid + } + + CK_TILE_HOST_DEVICE constexpr auto get_element_size() const { return element_size_; } + + // FIXME: this logic is wrong when getting bottome dimension lengths + template + CK_TILE_HOST_DEVICE constexpr auto get_hidden_dimension_length(number) const + { + static_assert(IDimHidden >= 0 && IDimHidden < ndim_hidden_, "wrong! out of range"); + + constexpr auto tmp = get_transform_and_its_upper_dimension(number{}); + + constexpr index_t itran = tmp[number<0>{}]; + constexpr index_t idim_up = tmp[number<1>{}]; + constexpr bool found = tmp[number<2>{}]; + + static_assert(found == true, + "wrong! not found matching transformation and upper-dimension"); + + return transforms_[number{}].get_upper_lengths()[number{}]; + } + + template + CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_length(number idim_top) const + { + return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_top)); + } + +#if 0 + // FIXME: get_hidden_dimension_length is wrong when getting bottome dimension lengths + template + CK_TILE_HOST_DEVICE constexpr index_t + get_bottom_dimension_length(number idim_bottom) const + { + return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_bottom)); + } +#endif + + CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_lengths() const + { + return generate_tuple([&](auto i) { return get_top_dimension_length(i); }, + number{}); + } + +#if 0 + // FIXME: get_hidden_dimension_length is wrong when getting bottome dimension lengths + CK_TILE_HOST_DEVICE constexpr auto GetBottomDimensionLengths() const + { + return generate_tuple([&](auto i) { return get_bottom_dimension_length(i); }, + number{}); + } +#endif + + template + CK_TILE_HOST_DEVICE constexpr auto calculate_bottom_index(const TopIdx& idx_top) const + { + static_assert(TopIdx::size() == TopDimensionHiddenIds::size(), + "wrong! 
# of dimension inconsistent"); + + constexpr index_t ntransform = get_num_of_transform(); + constexpr index_t ndim_hidden = get_num_of_hidden_dimension(); + + multi_index idx_hidden; + + // initialize uppest index + set_container_subset(idx_hidden, get_top_dimension_hidden_ids(), idx_top); + + // calculate hidden index + static_for{}([&](auto itran_p1) { + auto itran = itran_p1 - number<1>{}; + const auto& tran = get_transforms().at(itran); + constexpr auto dims_low = get_lower_dimension_hidden_idss().at(itran); + constexpr auto dims_up = get_upper_dimension_hidden_idss().at(itran); + + const auto idx_up = get_container_subset(idx_hidden, dims_up); + + multi_index idx_low; + + tran.calculate_lower_index(idx_low, idx_up); + + set_container_subset(idx_hidden, dims_low, idx_low); + }); + + return get_container_subset(idx_hidden, BottomDimensionHiddenIds{}); + } + + CK_TILE_HOST_DEVICE static constexpr bool is_static() + { + bool is_known = true; + + static_for<0, Transforms::size(), 1>{}([&](auto i) { + is_known &= remove_cvref_t::is_known_at_compile_time(); + }); + + return is_known && ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); } + + CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides( + const array& guaranteed_vector_lengths, + const array& guaranteed_vector_strides) + { + auto vector_lengths = guaranteed_vector_lengths; + auto vector_strides = guaranteed_vector_strides; + + static_for<0, get_num_of_transform(), 1>{}([&](auto itran) { + constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran); + constexpr auto up_dims = get_upper_dimension_hidden_idss().at(itran); + + const auto up_guaranteed_vector_lengths = + get_container_subset(guaranteed_vector_lengths, up_dims); + const auto up_guaranteed_vector_strides = + get_container_subset(guaranteed_vector_strides, up_dims); + + // only need type of transform + auto [up_vector_lengths, up_vector_strides] = + Transforms{}.at(itran).calculate_upper_dimension_safe_vector_length_strides( + get_container_subset(vector_lengths, low_dims), + get_container_subset(vector_strides, low_dims)); + + if constexpr(up_dims.size() > 0) + { + for(index_t i = 0; i < up_dims.size(); ++i) + { + up_vector_lengths(i) = (up_guaranteed_vector_lengths[i] != -1) + ? up_guaranteed_vector_lengths[i] + : up_vector_lengths[i]; + + up_vector_strides(i) = (up_guaranteed_vector_strides[i] != -1) + ? 
up_guaranteed_vector_strides[i] + : up_vector_strides[i]; + } + } + + set_container_subset(vector_lengths, up_dims, up_vector_lengths); + set_container_subset(vector_strides, up_dims, up_vector_strides); + }); + + constexpr auto top_dims = TopDimensionHiddenIds{}; + + return make_tuple(get_container_subset(vector_lengths, top_dims), + get_container_subset(vector_strides, top_dims)); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("tensor_adaptor{"); + + // + printf("transforms: "); + print(transforms_); + printf(", "); + + // + printf("LowerDimensionHiddenIds: "); + print(LowerDimensionHiddenIdss{}); + printf(", "); + + // + printf("UpperDimensionHiddenIds: "); + print(UpperDimensionHiddenIdss{}); + printf(", "); + + // + printf("BottomDimensionHiddenIds: "); + print(BottomDimensionHiddenIds{}); + printf(", "); + + // + printf("TopDimensionHiddenIds: "); + print(TopDimensionHiddenIds{}); + + printf("}"); + } + + private: + Transforms transforms_; + ElementSize element_size_; +}; + +// Transforms: Tuple +// LowerDimensionOldTopIdss: Tuple, ...> +// UpperDimensionNewTopIdss: Tuple, ...> +template +CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms, + LowerDimensionOldTopIdss, + UpperDimensionNewTopIdss) +{ + constexpr index_t ntransform = Transforms::size(); + + static_assert(LowerDimensionOldTopIdss::size() == ntransform && + UpperDimensionNewTopIdss::size() == ntransform, + "wrong!"); + + // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss + constexpr auto all_low_dim_old_top_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{}); + + constexpr auto all_up_dim_new_top_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{}); + + static_assert(is_valid_sequence_map::value && + is_valid_sequence_map::value, + "wrong!"); + + constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size(); + constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size(); + + // low_dim_hidden_idss + constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{}; + + // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom + constexpr auto up_dim_hidden_idss = generate_tuple( + [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + number{}; }, + number{}); + + // bottom_dim_hidden_ids + constexpr auto bottom_dim_hidden_ids = + typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{}; + + // top_dim_hidden_ids + constexpr auto top_dim_hidden_ids = + typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + number{}; + + return tensor_adaptor, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t>{transforms}; +} + +// TODO: How to fix this? 
It uses an struct instead of lambda because lambda +// doesn't have constructor, and to put it outside the scope where it is used +// (transform_tensor_adaptor) because template cannot be defined inside a function +// template +template +struct lambda_get_up_dim_num +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(I) const + { + using Tran = remove_reference_t; + return number{}; + } +}; + +template +CK_TILE_HOST_DEVICE constexpr auto +transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor, + const NewTransforms& new_transforms, + NewLowerDimensionOldTopIdss, + NewUpperDimensionNewTopIdss) +{ + // sanity check + { + static_assert(NewTransforms::size() == NewLowerDimensionOldTopIdss::size() && + NewTransforms::size() == NewUpperDimensionNewTopIdss::size(), + "wrong! inconsitent number of transform"); + + constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, + NewLowerDimensionOldTopIdss{}); + + constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, + NewUpperDimensionNewTopIdss{}); + + static_assert(is_valid_sequence_map::value && + is_valid_sequence_map::value, + "wrong!"); + } + + // lower dimension's hidden idss + // convert lower dimension top idss (tuple of sequences) to hidden idss (tuple of + // sequences) + constexpr auto low_dim_hidden_idss = transform_tuples( + // convert lower dimension top ids (a sequence) to hidden ids (a sequence) + [](auto low_dim_top_ids) constexpr { + return transform_sequences( + // convert lower dimension top id to hidden id + [](auto low_dim_top_id) constexpr { + return OldTensorAdaptor::get_top_dimension_hidden_ids()[low_dim_top_id]; + }, + low_dim_top_ids); + }, + NewLowerDimensionOldTopIdss{}); + + constexpr index_t num_new_transform = NewTransforms::size(); + + // upper dimension's hidden idss + constexpr index_t old_hidden_dim_number = OldTensorAdaptor::get_num_of_hidden_dimension(); + + constexpr auto up_dim_numbers = + generate_sequence(lambda_get_up_dim_num{}, number{}); + + constexpr auto up_dim_numbers_scan = merge_sequences( + sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus{}, number<0>{})); + + constexpr auto up_dim_hidden_idss = generate_tuple( + [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr { + return + typename arithmetic_sequence_gen::type{}; + }, + number{}); + + // new top dimension's hidden ids + constexpr auto unordered_new_top_dim_hidden_ids = unpack( + [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); + + constexpr auto new_top_dim_unordered2ordered = unpack( + [](auto... 
xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{}); + + constexpr auto new_top_dim_hidden_ids = + unordered_new_top_dim_hidden_ids.reorder_old_to_new(new_top_dim_unordered2ordered); + + // put everything together + const auto all_transforms = + container_concat(old_tensor_adaptor.get_transforms(), new_transforms); + + constexpr auto all_low_dim_hidden_idss = + container_concat(OldTensorAdaptor::get_lower_dimension_hidden_idss(), low_dim_hidden_idss); + + constexpr auto all_up_dim_hidden_idss = + container_concat(OldTensorAdaptor::get_upper_dimension_hidden_idss(), up_dim_hidden_idss); + + return tensor_adaptor< + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t>{all_transforms}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0, + const TensorAdaptor1& adaptor1) +{ + static_assert(TensorAdaptor0::get_num_of_top_dimension() == + TensorAdaptor1::get_num_of_bottom_dimension(), + "wrong!"); + + // all_transforms = transform0 + transform1 + const auto all_transforms = + container_concat(adaptor0.get_transforms(), adaptor1.get_transforms()); + + // shift + constexpr index_t adaptor0_max_hidden_id = [&]() { + index_t adaptor0_max_hidden_id_ = numeric::min(); + + static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) { + constexpr index_t ndim_low = + TensorAdaptor0{}.get_transforms()[itran].get_num_of_lower_dimension(); + + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + adaptor0_max_hidden_id_ = + max(adaptor0_max_hidden_id_, + TensorAdaptor0::get_lower_dimension_hidden_idss()[itran][idim_low].value); + }); + + constexpr index_t ndim_up = + TensorAdaptor0{}.get_transforms()[itran].get_num_of_upper_dimension(); + + static_for<0, ndim_up, 1>{}([&](auto idim_up) { + adaptor0_max_hidden_id_ = + max(adaptor0_max_hidden_id_, + TensorAdaptor0::get_upper_dimension_hidden_idss()[itran][idim_up].value); + }); + }); + + return adaptor0_max_hidden_id_; + }(); + + constexpr index_t adaptor1_min_hidden_id = [&]() { + index_t adaptor1_min_hidden_id_ = numeric::max(); + + static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) { + constexpr index_t ndim_low = + TensorAdaptor1{}.get_transforms()[itran].get_num_of_lower_dimension(); + + // get the min of all lower dimenions, but not bottom dimension (because their id will + // be matched with top id from adaptor0) + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + constexpr index_t low_dim_hidden_id = + TensorAdaptor1::get_lower_dimension_hidden_idss()[itran][idim_low].value; + + bool is_bottom_dim = false; + static_for<0, TensorAdaptor1::get_num_of_bottom_dimension(), 1>{}([&](auto i) { + if constexpr(low_dim_hidden_id == + TensorAdaptor1::get_bottom_dimension_hidden_ids()[i]) + { + is_bottom_dim = true; + } + }); + + if(!is_bottom_dim) + { + adaptor1_min_hidden_id_ = min(adaptor1_min_hidden_id_, low_dim_hidden_id); + } + }); + + constexpr index_t ndim_up = + TensorAdaptor1{}.get_transforms()[itran].get_num_of_upper_dimension(); + + // get the min of all upper dimensions + static_for<0, ndim_up, 1>{}([&](auto idim_up) { + adaptor1_min_hidden_id_ = + min(adaptor1_min_hidden_id_, + TensorAdaptor1::get_upper_dimension_hidden_idss()[itran][idim_up].value); + }); + }); + + return adaptor1_min_hidden_id_; + }(); + + constexpr index_t adaptor1_hidden_id_shift = + adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id; + + constexpr index_t ndim_bottom_1 = TensorAdaptor1::get_num_of_bottom_dimension(); + + 
// all_low_dim_hidden_idss = + // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hiden_idss_1)) + constexpr auto low_dim_hidden_idss_1 = generate_tuple( + // generate sequence of ids for a transform + [&](auto itran) { + constexpr auto ndim_low_1 = + TensorAdaptor1::get_lower_dimension_hidden_idss()[itran].size(); + + constexpr auto low_dim_hidden_ids_1 = + TensorAdaptor1::get_lower_dimension_hidden_idss()[itran]; + + // sequence in, sequence out + constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr + { + auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1); + + // shift hidden id so every dim id is unique + static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { + low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift; + }); + + // match hidden id + static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { + static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) { + // if this low dim is bottom dim, then do id matching + if constexpr(low_dim_hidden_ids_1[idim_low_1] == + TensorAdaptor1::get_bottom_dimension_hidden_ids() + [idim_bottom_1]) + { + low_dim_hidden_ids_1_mod_(idim_low_1) = + TensorAdaptor0::get_top_dimension_hidden_ids()[idim_bottom_1]; + } + }); + }); + + return low_dim_hidden_ids_1_mod_; + } + (); + + return generate_sequence_v2( + [&](auto i) constexpr { return number{}; }, + number{}); + }, + number{}); + + constexpr auto all_low_dim_hidden_idss = + container_concat(TensorAdaptor0::get_lower_dimension_hidden_idss(), low_dim_hidden_idss_1); + + // all_up_dim_hidden_idss = + // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hiden_idss_1) + constexpr auto up_dim_hidden_idss_1 = generate_tuple( + // generate sequence of ids for a transform + [&](auto itran) { + constexpr auto ndim_up_1 = + TensorAdaptor1::get_upper_dimension_hidden_idss()[itran].size(); + + constexpr auto up_dim_hidden_ids_1 = + TensorAdaptor1::get_upper_dimension_hidden_idss()[itran]; + + // sequence in, constexpr tuple out + constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr + { + auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1); + + // shift hidden id + static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) { + up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift; + }); + + return up_dim_hidden_ids_1_mod_; + } + (); + + // constexpr tuple to sequence + return generate_sequence_v2( + [&](auto i) constexpr { return number{}; }, + number{}); + }, + number{}); + + constexpr auto all_up_dim_hidden_idss = + container_concat(TensorAdaptor0::get_upper_dimension_hidden_idss(), up_dim_hidden_idss_1); + + // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0 + constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::get_bottom_dimension_hidden_ids(); + + // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1) + constexpr auto top_dim_hidden_ids = + TensorAdaptor1::get_top_dimension_hidden_ids() + number{}; + + // put everything together + return tensor_adaptor, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t>{all_transforms}; +} + +template = 2, bool>::type = false> +CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs) +{ + return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...)); +} + +} // namespace ck_tile + +// Macro function +// construct constexpr tensor_adaptor from constexpr encoding +// encoded_tensor_adaptor are Tuple of following objects: +// 1. encoded transforms (array of fixed size). 
Each encoded transform is a Tuple of following: +// 1.1 name (coord_transform_enum) +// 1.2 meta data for constructor of the transform +// 1.3 num of lower dimension (index_t) +// 1.4 lower dimension Ids (array of fixed size) +// 1.5 num of up dimension (index_t) +// 1.6 upper dimension Ids (array of fixed size) +// 2. num of transforms (index_t) +// 3. encoded bottom dimension Ids (array of fixed size) +// 4. num of bottom dimension (index_t) +// 5. encoded top dimension Ids (array of fixed size) +// 6. num of top dimension (index_t) +#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \ + [encoded_tensor_adaptor]() { \ + using namespace ck_tile; \ + \ + constexpr auto encoded_transforms = encoded_tensor_adaptor.template at<0>(); \ + constexpr index_t num_transform = encoded_tensor_adaptor.template at<1>(); \ + constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \ + constexpr index_t num_bottom_dim = encoded_tensor_adaptor.template at<3>(); \ + constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \ + constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \ + \ + constexpr auto trans = [&encoded_transforms]() { \ + return generate_tuple( \ + [&encoded_transforms](auto i) constexpr { \ + constexpr auto name = encoded_transforms[i].template at<0>(); \ + constexpr auto meta_data = encoded_transforms[i].template at<1>(); \ + constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \ + constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \ + \ + static_assert(name == coord_transform_enum::pass_through || \ + name == coord_transform_enum::pad || \ + name == coord_transform_enum::embed || \ + name == coord_transform_enum::merge || \ + name == coord_transform_enum::unmerge || \ + name == coord_transform_enum::replicate, \ + ""); \ + \ + if constexpr(name == coord_transform_enum::pass_through) \ + { \ + index_t pos = 0; \ + auto low_len = meta_data.template pop(pos); \ + \ + return make_pass_through_transform(low_len); \ + } \ + else if constexpr(name == coord_transform_enum::pad) \ + { \ + index_t pos = 0; \ + auto low_len = meta_data.template pop(pos); \ + auto left_pad = meta_data.template pop(pos); \ + auto right_pad = meta_data.template pop(pos); \ + \ + return make_pad_transform(low_len, left_pad, right_pad); \ + } \ + else if constexpr(name == coord_transform_enum::embed) \ + { \ + index_t pos = 0; \ + auto up_lens = meta_data.template pop>(pos); \ + auto coefficients = \ + meta_data.template pop>(pos); \ + \ + return make_embed_transform(up_lens, coefficients); \ + } \ + else if constexpr(name == coord_transform_enum::merge) \ + { \ + index_t pos = 0; \ + auto low_lens = meta_data.template pop>(pos); \ + \ + return make_merge_transform(low_lens); \ + } \ + else if constexpr(name == coord_transform_enum::unmerge) \ + { \ + index_t pos = 0; \ + auto up_lens = meta_data.template pop>(pos); \ + \ + return make_unmerge_transform(up_lens); \ + } \ + else if constexpr(name == coord_transform_enum::replicate) \ + { \ + index_t pos = 0; \ + auto up_lens = meta_data.template pop>(pos); \ + \ + return make_replicate_transform(up_lens); \ + } \ + }, \ + number{}); \ + }(); \ + \ + constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() { \ + return generate_tuple( \ + [&encoded_transforms](auto i) { \ + constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \ + constexpr auto low_dims = encoded_transforms[i].template at<3>(); \ + \ + return TO_SEQUENCE(low_dims, 
num_low_dim); \ + }, \ + number()); \ + }(); \ + \ + constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] { \ + return generate_tuple( \ + [&encoded_transforms](auto i) { \ + constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \ + constexpr auto up_dims = encoded_transforms[i].template at<5>(); \ + \ + return TO_SEQUENCE(up_dims, num_up_dim); \ + }, \ + number()); \ + }(); \ + \ + constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \ + constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \ + \ + return tensor_adaptor, \ + remove_cvref_t, \ + remove_cvref_t, \ + remove_cvref_t, \ + remove_cvref_t>{trans}; \ + }() + +// Macro function +// construct static tensor_adaptor from constexpr encoding +// encoded_tensor_adaptor are Tuple of following objects: +// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of following: +// 1.1 name (coord_transform_enum) +// 1.2 meta data for constructor of the transform +// 1.3 num of lower dimension (index_t) +// 1.4 lower dimension Ids (array of fixed size) +// 1.5 num of up dimension (index_t) +// 1.6 upper dimension Ids (array of fixed size) +// 2. num of transforms (index_t) +// 3. encoded bottom dimension Ids (array of fixed size) +// 4. num of bottom dimension (index_t) +// 5. encoded top dimension Ids (array of fixed size) +// 6. num of top dimension (index_t) +#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \ + [encoded_tensor_adaptor]() { \ + using namespace ck_tile; \ + \ + constexpr auto encoded_transforms = encoded_tensor_adaptor.template at<0>(); \ + constexpr index_t num_transform = encoded_tensor_adaptor.template at<1>(); \ + constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \ + constexpr index_t num_bottom_dim = encoded_tensor_adaptor.template at<3>(); \ + constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \ + constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \ + \ + constexpr auto trans = [&encoded_transforms]() { \ + return generate_tuple( \ + [&encoded_transforms](auto i) constexpr { \ + constexpr auto name = encoded_transforms[i].template at<0>(); \ + constexpr auto meta_data = encoded_transforms[i].template at<1>(); \ + constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \ + constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \ + \ + static_assert(name == coord_transform_enum::pass_through || \ + name == coord_transform_enum::pad || \ + name == coord_transform_enum::embed || \ + name == coord_transform_enum::merge || \ + name == coord_transform_enum::unmerge || \ + name == coord_transform_enum::replicate, \ + ""); \ + \ + if constexpr(name == coord_transform_enum::pass_through) \ + { \ + constexpr index_t low_len = meta_data.template get(0); \ + \ + return make_pass_through_transform(number{}); \ + } \ + else if constexpr(name == coord_transform_enum::pad) \ + { \ + constexpr index_t low_len = meta_data.template get(0); \ + \ + constexpr index_t left_pad = \ + meta_data.template get(sizeof(low_len)); \ + \ + constexpr index_t right_pad = \ + meta_data.template pop(sizeof(low_len) + sizeof(left_pad)); \ + \ + return make_pad_transform( \ + number{}, number{}, number{}); \ + } \ + else if constexpr(name == coord_transform_enum::embed) \ + { \ + constexpr auto up_lens = \ + meta_data.template get>(0); \ + \ + constexpr auto coefficients = \ + meta_data.template get>(sizeof(up_lens)); \ + \ + 
return make_embed_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim), \ + TO_TUPLE_OF_NUMBER(coefficients, num_up_dim)); \ + } \ + else if constexpr(name == coord_transform_enum::merge) \ + { \ + constexpr auto low_lens = \ + meta_data.template get>(0); \ + \ + return make_merge_transform(TO_TUPLE_OF_NUMBER(low_lens, num_low_dim)); \ + } \ + else if constexpr(name == coord_transform_enum::unmerge) \ + { \ + constexpr auto up_lens = \ + meta_data.template get>(0); \ + \ + return make_unmerge_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \ + } \ + else if constexpr(name == coord_transform_enum::replicate) \ + { \ + constexpr auto up_lens = \ + meta_data.template get>(0); \ + \ + return make_replicate_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \ + } \ + }, \ + number{}); \ + }(); \ + \ + constexpr auto low_dim_idss = [&encoded_transforms]() { \ + return generate_tuple( \ + [&encoded_transforms](auto i) { \ + constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \ + constexpr auto low_dims = encoded_transforms[i].template at<3>(); \ + \ + return TO_SEQUENCE(low_dims, num_low_dim); \ + }, \ + number()); \ + }(); \ + \ + constexpr auto up_dim_idss = [&encoded_transforms] { \ + return generate_tuple( \ + [&encoded_transforms](auto i) { \ + constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \ + constexpr auto up_dims = encoded_transforms[i].template at<5>(); \ + \ + return TO_SEQUENCE(up_dims, num_up_dim); \ + }, \ + number()); \ + }(); \ + \ + constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \ + constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \ + \ + return tensor_adaptor, \ + remove_cvref_t, \ + remove_cvref_t, \ + remove_cvref_t, \ + remove_cvref_t>{trans}; \ + }() diff --git a/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp b/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp new file mode 100644 index 000000000..0d398d423 --- /dev/null +++ b/include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp @@ -0,0 +1,257 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
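+
+// A tensor_adaptor_coordinate caches the full hidden index of an adaptor so the
+// bottom index can be updated incrementally as the top index moves, instead of
+// being recomputed from scratch. A hedged sketch (`adaptor`, `m` and `n` are
+// placeholders):
+//
+//   auto coord = ck_tile::make_tensor_adaptor_coordinate(adaptor, ck_tile::make_multi_index(m, n));
+//   // ... use coord.get_bottom_index() ...
+//   ck_tile::move_tensor_adaptor_coordinate(adaptor, coord, ck_tile::make_multi_index(0, 1));
+//   // coord now refers to top index (m, n + 1); by default only transforms whose
+//   // upper index actually changed are re-evaluated.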
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct tensor_adaptor_coordinate +{ + static constexpr index_t ndim_bottom_ = BottomDimensionHiddenIds::size(); + static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size(); + + using HiddenIndex = multi_index; + using BottomIndex = multi_index; + using TopIndex = multi_index; + + public: + CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate() = default; + + CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate(const HiddenIndex& idx_hidden) + : idx_hidden_{idx_hidden} + { + } + + CK_TILE_HOST_DEVICE constexpr auto get_top_index() const + { + return get_container_subset(idx_hidden_, TopDimensionHiddenIds{}); + } + + CK_TILE_HOST_DEVICE constexpr auto get_bottom_index() const + { + return get_container_subset(idx_hidden_, BottomDimensionHiddenIds{}); + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const { return idx_hidden_; } + + CK_TILE_HOST_DEVICE constexpr auto& get_hidden_index() { return idx_hidden_; } + + // + HiddenIndex idx_hidden_; +}; + +template +CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor& adaptor, + const TopIndex& idx_top) +{ + static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(), + "wrong! # of dimension inconsistent"); + + constexpr index_t ntransform = Adaptor::get_num_of_transform(); + constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension(); + constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids(); + constexpr auto top_dim_ids = Adaptor::get_top_dimension_hidden_ids(); + + multi_index idx_hidden; + + // initialize visible index + set_container_subset(idx_hidden, top_dim_ids, idx_top); + + // calculate hidden index + static_for{}([&adaptor, &idx_hidden](auto itran_p1) { + auto itran = itran_p1 - number<1>{}; + const auto& tran = adaptor.get_transforms().at(itran); + constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran); + constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran); + + const auto idx_up = get_container_subset(idx_hidden, dims_up); + + multi_index idx_low; + + tran.calculate_lower_index(idx_low, idx_up); + + set_container_subset(idx_hidden, dims_low, idx_low); + }); + + return tensor_adaptor_coordinate, + remove_cvref_t>{idx_hidden}; +} + +template +CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor, + AdaptorCoord& coord, + const TopIndex& idx_diff_top, + BottomIndex& idx_diff_bottom) +{ + constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension(); + constexpr index_t ndim_top = Adaptor::get_num_of_top_dimension(); + // constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension(); + constexpr index_t ntransform = Adaptor::get_num_of_transform(); + + // static_assert(TopIndex::size() == ndim_top && BottomIndex::size() == ndim_bottom, ""); + + // judge whether calculation of lower diff is needed for each transform + // use index_t for boolean type + auto do_transforms = make_zero_multi_index(); + + if constexpr(JudgeDoTransforms) + { + auto 
is_non_zero_diff = make_zero_multi_index(); + + // decide do_transform by checkout non-zero index diff components + multi_index non_zero_diff_pick_top; + + static_for<0, ndim_top, 1>{}( + [&](auto i) { non_zero_diff_pick_top(i) = (idx_diff_top[i] != 0); }); + + set_container_subset( + is_non_zero_diff, Adaptor::get_top_dimension_hidden_ids(), non_zero_diff_pick_top); + + static_for{}([&](auto itran) { + constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran); + constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran); + + const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up); + + multi_index non_zero_diff_pick_low; + + // if any of upper index diff components is non-zero, then + // 1) Need to do this transform + // 2) all components of lower index diff will assume to be non-zero and need to be + // computed + const bool idx_diff_up_has_non_zero = container_reduce( + non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false); + + do_transforms(itran) = idx_diff_up_has_non_zero; + + static_for<0, dims_low.size(), 1>{}( + [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; }); + + set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low); + }); + } + else + { + static_for{}([&](auto itran) { do_transforms(itran) = 1; }); + } + + // this is what needs to be calculated + auto idx_diff_hidden = make_zero_multi_index(); + + // initialize top index diff + set_container_subset(idx_diff_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_diff_top); + + // this is what needs to be updated + auto& idx_hidden = coord.get_hidden_index(); + + // update top index + auto idx_hidden_pick_top = + get_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids()); + + idx_hidden_pick_top += idx_diff_top; + + set_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_hidden_pick_top); + + // update rest of hidden index + static_for{}([&](auto itran) { + if(do_transforms[itran]) + { + const auto& tran = adaptor.get_transforms().at(itran); + constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran); + constexpr auto dims_up = Adaptor::get_upper_dimension_hidden_idss().at(itran); + + const auto idx_up_new = get_container_subset(idx_hidden, dims_up); + auto idx_low = get_container_subset(idx_hidden, dims_low); + const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up); + + multi_index idx_diff_low; + + tran.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up_new); + + set_container_subset(idx_diff_hidden, dims_low, idx_diff_low); + set_container_subset(idx_hidden, dims_low, idx_low); + } + }); + + // set bottom index diff + idx_diff_bottom = + get_container_subset(idx_diff_hidden, Adaptor::get_bottom_dimension_hidden_ids()); +} + +template +CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor, + AdaptorCoord& coord, + const TopIndex& idx_diff_top) +{ + constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension(); + + multi_index tmp; + + move_tensor_adaptor_coordinate(adaptor, coord, idx_diff_top, tmp); +} + +template +CK_TILE_HOST_DEVICE constexpr bool +adaptor_coordinate_is_valid_assuming_top_index_is_valid(const Adaptor& adaptor, + const AdaptorCoord& coord) +{ + bool valid = true; + + constexpr index_t ntransform = Adaptor::get_num_of_transform(); + + const auto& idx_hidden = coord.get_hidden_index(); + + static_for{}([&adaptor, &idx_hidden, &valid](auto itran) { + const 
auto tran = adaptor.get_transforms().at(itran); + + // check validity, only if current transformation does not always has a valid mapping + if constexpr(!decltype(tran)::is_valid_upper_index_always_mapped_to_valid_lower_index()) + { + const auto idx_up = get_container_subset( + idx_hidden, Adaptor::get_upper_dimension_hidden_idss().at(itran)); + + // Comment: using valid = valid && .. will result in weird control flow in ISA + valid &= tran.is_valid_upper_index_mapped_to_valid_lower_index(idx_up); + } + }); + + return valid; +} + +template +CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor, + const AdpatorCoord& coord) +{ + // check top index + const auto& idx_top = coord.get_top_index(); + + bool is_top_index_valid = true; + + static_for<0, Adaptor::get_num_of_dimension(), 1>{}( + [&is_top_index_valid, &idx_top, &adaptor](auto i) { + is_top_index_valid = + is_top_index_valid && (idx_top[i] >= 0 && idx_top[i] < adaptor.get_length(i)); + }); + + // check other hidden index + return is_top_index_valid && + adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_coordinate.hpp b/include/ck_tile/core/tensor/tensor_coordinate.hpp new file mode 100644 index 000000000..9b8fe731f --- /dev/null +++ b/include/ck_tile/core/tensor/tensor_coordinate.hpp @@ -0,0 +1,92 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct tensor_coordinate + : public tensor_adaptor_coordinate, TopDimensionHiddenIds> +{ + using Base = tensor_adaptor_coordinate, TopDimensionHiddenIds>; + + // TODO make these private + static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size(); + + using HiddenIndex = multi_index; + using TopIndex = multi_index; + + public: + CK_TILE_HOST_DEVICE constexpr tensor_coordinate() = default; + + CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const HiddenIndex& idx_hidden) + : Base{idx_hidden} + { + } + + // construct from TensorAdaptorCoordinte base class + CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const Base& adaptor_coord) : Base{adaptor_coord} + { + } + + CK_TILE_HOST_DEVICE constexpr auto get_index() const { return Base::get_top_index(); } + + CK_TILE_HOST_DEVICE constexpr index_t get_offset() const + { + return Base::get_bottom_index()[number<0>{}]; + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const + { + return Base::get_hidden_index(); + } + + CK_TILE_HOST_DEVICE auto& get_hidden_index() { return Base::get_hidden_index(); } +}; + +template +CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc, + const TopIndex& idx_top) +{ + const auto adaptor_coord = make_tensor_adaptor_coordinate(tensor_desc, idx_top); + + return tensor_coordinate>{ + adaptor_coord}; +} + +template +CK_TILE_HOST_DEVICE constexpr void +move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& 
coord, const Index& coord_step) +{ + move_tensor_adaptor_coordinate(tensor_desc, coord, coord_step); +} + +template +CK_TILE_HOST_DEVICE constexpr bool +coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor_desc, + const TensorCoord& coord) +{ + return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord); +} + +template +CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc, + const TensorCoord& coord) +{ + return adaptor_coordinate_is_valid(tensor_desc, coord); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_descriptor.hpp b/include/ck_tile/core/tensor/tensor_descriptor.hpp new file mode 100644 index 000000000..0c3e04f31 --- /dev/null +++ b/include/ck_tile/core/tensor/tensor_descriptor.hpp @@ -0,0 +1,467 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// Transforms: Tuple +// LowerDimensionHiddenIdss : Tuple, ...> +// UpperDimensionHiddenIdss : Tuple, ...> +// TopDimensionHiddenIds> : sequence<...> +template +struct tensor_descriptor : public tensor_adaptor, + TopDimensionHiddenIds> +{ + using Base = tensor_adaptor, + TopDimensionHiddenIds>; + + using ElementSpaceSizeType = ElementSpaceSize; + + constexpr static index_t ntransform_ = Base::get_num_of_transform(); + constexpr static index_t ndim_hidden_ = Base::get_num_of_hidden_dimension(); + constexpr static index_t ndim_top_ = Base::get_num_of_top_dimension(); + + using GuaranteedVectorLengths = GuaranteedVectorLengths_; + using GuaranteedVectorStrides = GuaranteedVectorSrides_; + + static_assert(GuaranteedVectorLengths::size() == ndim_hidden_ && + GuaranteedVectorStrides::size() == ndim_hidden_, + "wrong! inconsistent # of hidden dimensions"); + + using TopIndex = multi_index; + using HiddenIndex = multi_index; + + public: + CK_TILE_HOST_DEVICE constexpr tensor_descriptor() = default; + + CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Transforms& transforms, + ElementSpaceSize element_space_size) + : Base{transforms}, element_space_size_{element_space_size} + + { + static_assert(Transforms::size() == ntransform_ && + LowerDimensionHiddenIdss::size() == ntransform_ && + UpperDimensionHiddenIdss::size() == ntransform_, + "wrong! 
inconsistent # of transformations"); + + // TODO check dependency of dimensions is valid + } + + // construct from tensor_adaptor base class + CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Base& adaptor, + ElementSpaceSize element_space_size) + : Base{adaptor}, element_space_size_{element_space_size} + { + } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension() + { + return Base::get_num_of_top_dimension(); + } + + template + CK_TILE_HOST_DEVICE constexpr auto get_length(number idim) const + { + return Base::get_top_dimension_length(idim); + } + + CK_TILE_HOST_DEVICE constexpr auto get_lengths() const + { + return Base::get_top_dimension_lengths(); + } + + CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const + { + return element_space_size_; + } + + template + CK_TILE_HOST_DEVICE constexpr index_t calculate_offset(const Idx& idx) const + { + return Base::calculate_bottom_index(idx)[number<0>{}]; + } + + // TODO make these private + CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const + { + return Base::get_transforms(); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss() + { + return Base::get_lower_dimension_hidden_idss(); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss() + { + return Base::get_upper_dimension_hidden_idss(); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids() + { + return Base::get_top_dimension_hidden_ids(); + } + + CK_TILE_HOST_DEVICE static constexpr bool is_static() + { + return Base::is_known_at_compile_time() && + ck_tile::is_known_at_compile_time::value; + } + + CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); } + + CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides() + { + return Base::get_top_dimension_safe_vector_length_strides( + to_array(GuaranteedVectorLengths{}), + to_array(GuaranteedVectorStrides{})); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("tensor_descriptor{"); + + // tensor_adaptor + Base::print(); + printf(", "); + + // element_space_size_ + printf("element_space_size_: "); + print(element_space_size_); + + printf("}"); + } + + // TODO make these private + ElementSpaceSize element_space_size_; +}; + +template +CK_TILE_HOST_DEVICE constexpr auto +make_tensor_descriptor_from_adaptor(const Adaptor& adaptor, + const ElementSpaceSize& element_space_size) +{ + constexpr index_t NDimHidden = Adaptor::get_num_of_hidden_dimension(); + + return tensor_descriptor, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + typename uniform_sequence_gen::type, + typename uniform_sequence_gen::type>{ + adaptor, element_space_size}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, + const NewTransforms& new_transforms, + NewLowerDimensionOldTopIdss, + NewUpperDimensionNewTopIdss) +{ + const auto element_space_size = old_tensor_desc.get_element_space_size(); + + const auto new_tensor_adaptor = transform_tensor_adaptor(old_tensor_desc, + new_transforms, + NewLowerDimensionOldTopIdss{}, + NewUpperDimensionNewTopIdss{}); + + constexpr index_t NDimHiddenOld = OldTensorDescriptor::get_num_of_hidden_dimension(); + constexpr index_t NDimHiddenNew = decltype(new_tensor_adaptor)::get_num_of_hidden_dimension(); + + using NewGuaranteedVectorLengths = typename sequence_merge< + typename OldTensorDescriptor::GuaranteedVectorLengths, + typename 
uniform_sequence_gen::type>::type; + + using NewGuaranteedVectorStrides = typename sequence_merge< + typename OldTensorDescriptor::GuaranteedVectorStrides, + typename uniform_sequence_gen::type>::type; + + return tensor_descriptor< + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + NewGuaranteedVectorLengths, + NewGuaranteedVectorStrides>{new_tensor_adaptor, element_space_size}; +} + +namespace detail { + +template +CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths, + const Strides& strides, + number i, + AccOld acc_old) +{ + auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i]; + + if constexpr(i.value < Lengths::size() - 1) + { + return calculate_element_space_size_impl(lengths, strides, i + number<1>{}, acc_new); + } + else + { + return acc_new; + } +} + +} // namespace detail + +/* + * These functions create naive tensor descriptor + */ + +// Lengths..., Strides... could be: +// 1) index_t, which is known at run-time, or +// 2) number<>, which is known at compile-time +// element_space_size could be: +// 1) long_index_t, or +// 2) long_number<> +template ::type = false> +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_descriptor(const tuple& lengths, + const tuple& strides, + number = number<-1>{}, + number = number<-1>{}) +{ + constexpr index_t N = sizeof...(Lengths); + + const auto transforms = make_tuple(make_embed_transform(lengths, strides)); + + constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{}); + + constexpr auto up_dim_hidden_idss = + make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{}); + + constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; + + const auto element_space_size = + detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{}); + + using GuaranteedVectorLengths = + typename sequence_merge::type, + sequence>::type; + + using GuaranteedVectorStrides = + typename sequence_merge::type, + sequence>::type; + + return tensor_descriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t, + GuaranteedVectorLengths, + GuaranteedVectorStrides>{transforms, element_space_size}; +} + +// tensor descriptor with offset, the offset will not be added into element space size +// only have an information of the starting offset, and will impact on offset calculation +template ::type = false> +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_descriptor_with_offset(const tuple& lengths, + const tuple& strides, + const offset& os, + number = number<-1>{}, + number = number<-1>{}) +{ + const auto desc_0 = [&]() { + const auto element_space_size = detail::calculate_element_space_size_impl( + lengths, strides, number<0>{}, long_number<1>{}); + + const auto transforms = make_tuple(make_offset_transform(element_space_size, os)); + + constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{}); + + constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{}); + + constexpr auto visible_dim_hidden_ids = sequence<1>{}; + + using GuaranteedVectorLengths = + typename sequence_merge::type, + sequence>::type; + + using GuaranteedVectorStrides = + typename sequence_merge::type, + sequence>::type; + + return tensor_descriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t, + GuaranteedVectorLengths, + GuaranteedVectorStrides>{transforms, element_space_size}; + }(); + + constexpr index_t N = sizeof...(Lengths); + + return transform_tensor_descriptor( + desc_0, + 
make_tuple(make_embed_transform(lengths, strides)), + make_tuple(sequence<0>{}), + make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{})); +} + +// Lengths... could be: +// 1) index_t, which is known at run-time, or +// 2) number<>, which is known at compile-time +// element_space_size could be: +// 1) long_index_t, or +// 2) long_number<> +template +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_descriptor_packed(const tuple& lengths, + number = number<-1>{}) +{ + constexpr index_t N = sizeof...(Lengths); + + const auto transforms = make_tuple(make_unmerge_transform(lengths)); + + constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{}); + + constexpr auto up_dim_hidden_idss = + make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{}); + + constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; + + const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{}); + + using GuaranteedVectorLengths = + typename sequence_merge::type, + sequence>::type; + + using GuaranteedVectorStrides = + typename sequence_merge::type, sequence<1>>::type; + + return tensor_descriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t, + GuaranteedVectorLengths, + GuaranteedVectorStrides>{transforms, element_space_size}; +} + +template ::type = false> +CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offset( + const tuple& lengths, + const Offset& offset, + number = number<-1>{}) +{ + const auto desc_0 = [&]() { + const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{}); + + const auto transforms = make_tuple(make_offset_transform(element_space_size, offset)); + + constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{}); + + constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{}); + + constexpr auto visible_dim_hidden_ids = sequence<1>{}; + + using GuaranteedVectorLengths = + typename sequence_merge::type, + sequence>::type; + + using GuaranteedVectorStrides = + typename sequence_merge::type, sequence<1>>::type; + + return tensor_descriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t, + GuaranteedVectorLengths, + GuaranteedVectorStrides>{transforms, element_space_size}; + }(); + + constexpr index_t N = sizeof...(Lengths); + + return transform_tensor_descriptor( + desc_0, + make_tuple(make_unmerge_transform(lengths)), + make_tuple(sequence<0>{}), + make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{})); +} + +// Lengths... 
could be: +// 1) index_t, which is known at run-time, or +// 2) number<>, which is known at compile-time +// align could be: +// 1) index_t, or +// 2) number<> +template +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_descriptor_aligned(const tuple& lengths, Align align) +{ + constexpr auto I1 = number<1>{}; + + constexpr index_t N = sizeof...(Lengths); + + const auto stride_n_minus_2 = integer_least_multiple(lengths[number{}], align); + + auto strides = generate_tuple( + [&](auto i) { + if constexpr(i.value == N - 1) + { + return I1; + } + else if constexpr(i.value == N - 2) + { + return number{}; + } + else + { + return container_reduce( + lengths, multiplies{}, number{}, i + I1, number{}, I1); + } + }, + number{}); + + return make_naive_tensor_descriptor(lengths, strides); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_view.hpp b/include/ck_tile/core/tensor/tensor_view.hpp new file mode 100644 index 000000000..e37bd806d --- /dev/null +++ b/include/ck_tile/core/tensor/tensor_view.hpp @@ -0,0 +1,281 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/tensor/tensor_descriptor.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct tensor_view +{ + using buffer_view = remove_reference_t; + using DataType = typename buffer_view::type; + using TensorDesc = remove_cvref_t; + using TensorIndex = array; + using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{})); + + CK_TILE_HOST_DEVICE constexpr tensor_view() = default; + + CK_TILE_HOST_DEVICE constexpr tensor_view(const buffer_view& buffer_view, + const TensorDesc& desc) + : buf_{buffer_view}, desc_{desc} + { + } + + CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; } + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension() + { + return TensorDesc::get_num_of_top_dimension(); + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_buffer_view() const { return buf_; } + + CK_TILE_HOST_DEVICE constexpr auto& get_buffer_view() { return buf_; } + +#if 0 + CK_TILE_HOST_DEVICE constexpr DataType get_element(const TensorCoord& coord) const + { + return buf_.template get( + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); + } + + CK_TILE_HOST_DEVICE constexpr void set_element(const TensorCoord& coord, const DataType& x) + { + buf_.template set( + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + x); + } +#endif + // X is vector of DataType. + // "coord" is coordinate of DataType, not X. "coord" should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr remove_cvref_t + get_vectorized_elements(const TensorCoord& coord, + bool_constant = {}) const + { + return buf_.template get( + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + bool_constant{}); + } + + // X is vector of DataType. 
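    // (Illustrative example, assuming a half-precision element type: if DataType
    //  is fp16 and X is an 8-wide fp16 vector, a single call moves 8 consecutive
    //  fp16 elements starting at "coord", so the offset behind "coord" is
    //  expected to be a multiple of 8.)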
+ // "coord" is coordinate of DataType, not X. "coord" should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE void + get_vectorized_elements_raw(remove_cvref_t& dst, + const TensorCoord& coord, + bool_constant = {}) const + { + return buf_.template get_raw( + dst, + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord)); + } + + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(remove_cvref_t* smem, + const TensorCoord& coord) const + { + return buf_.template async_get(smem, coord.get_offset(), true /*not used*/); + } + + // X is vector of DataType. + // "coord" is coordinate of DataType, not X. "coord" should be aligned to X + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements( + const TensorCoord& coord, const X& x, bool_constant = {}) + { + buf_.template set( + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + x); + } + + template >::scalar_type, + typename vector_traits>::scalar_type>, + bool>::type = false> + CK_TILE_HOST_DEVICE constexpr void set_vectorized_elements_raw( + const TensorCoord& coord, const X& x, bool_constant = {}) + { + buf_.template set_raw( + coord.get_offset(), + coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord), + x); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("tensor_view{"); + + // buf_ + printf("buf_: "); + print(buf_); + printf(", "); + + // desc_ + printf("desc_: "); + print(desc_); + + printf("}"); + } + + // member + buffer_view buf_; + TensorDesc desc_; +}; + +// placeholder type if we want to opt-out a tile view parameter +struct null_tensor_view +{ +}; + +template +CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p, + const tensor_descriptor& desc) +{ + auto buffer_view = make_buffer_view(p, desc.get_element_space_size()); + + return tensor_view{buffer_view, desc}; +} + +template ::type = false> +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_view(DataType* p, + const tuple& lengths, + const tuple& strides, + number = number<-1>{}, + number = number<-1>{}) +{ + auto desc = make_naive_tensor_descriptor(lengths, + strides, + number{}, + number{}); + + auto buffer_view = make_buffer_view(p, desc.get_element_space_size()); + + return tensor_view{buffer_view, desc}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_naive_tensor_view_packed(DataType* p, + const tuple& lengths, + number = number<-1>{}) +{ + auto desc = + make_naive_tensor_descriptor_packed(lengths, number{}); + + auto buffer_view = make_buffer_view(p, desc.get_element_space_size()); + + return tensor_view{buffer_view, desc}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& old_tensor_view, + const NewTransforms& new_transforms, + NewLowerDimensionOldVisibleIdss, + NewUpperDimensionNewVisibleIdss) +{ + auto new_desc = transform_tensor_descriptor(old_tensor_view.desc_, + new_transforms, + NewLowerDimensionOldVisibleIdss{}, + NewUpperDimensionNewVisibleIdss{}); + + return tensor_view>{ + old_tensor_view.buf_, new_desc}; +} + +template + typename DoPads> // sequence +CK_TILE_HOST_DEVICE constexpr auto +pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads) +{ + constexpr index_t num_dim = 
DoPads::size(); + + static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(), + "wrong! inconsistent # of dimensions"); + + // transforms + const auto transforms = generate_tuple( + [&](auto idim) { + const auto old_length = tensor_view.get_tensor_descriptor().get_length(idim); + + const auto tile_length = tile_lengths[idim]; + + const auto new_length = integer_divide_ceil(old_length, tile_length) * tile_length; + + const auto pad_length = new_length - old_length; + + constexpr bool DoPad = DoPads::at(idim); + + const auto transform = + conditional_expr(make_right_pad_transform(old_length, pad_length), + make_pass_through_transform(old_length)); + + return transform; + }, + number{}); + + // lower dimension Id + const auto lower_dimss = + generate_tuple([&](auto idim) { return sequence{}; }, number{}); + + // upper dimension Id + const auto upper_dimss = lower_dimss; + + return transform_tensor_view(tensor_view, transforms, lower_dimss, upper_dimss); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp new file mode 100644 index 000000000..9fee2fd5c --- /dev/null +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -0,0 +1,759 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tile_distribution_encoding.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// distributed span +template +struct tile_distributed_span +{ + using Impl = sequence; + + static constexpr auto impl_ = Impl{}; + + CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; } +}; + +// distributed index +template +struct tile_distributed_index +{ + using Impl = sequence; + + static constexpr auto impl_ = Impl{}; + + CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; } +}; + +namespace detail { + +template +CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_span(sequence) +{ + return tile_distributed_span{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto make_tile_distributed_index(sequence) +{ + return tile_distributed_index{}; +} + +} // namespace detail + +template // FIXME: this is for hold ad-hoc but useful info, + // should be more elegnat +struct tile_distribution +{ + using PsYs2XsAdaptor = remove_cvref_t; + using Ys2DDescriptor = remove_cvref_t; + using DstrEncode = remove_cvref_t; + using DstrDetail = remove_cvref_t; + + static_assert(PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static(), + "wrong! 
should be static"); + + static constexpr index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension(); + static constexpr index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension(); + static constexpr index_t NDimP = PsYs2XsAdaptor::get_num_of_top_dimension() - NDimY; + static constexpr index_t NDimR = StaticTileDistributionEncoding_::NDimR; + + PsYs2XsAdaptor ps_ys_to_xs_; + Ys2DDescriptor ys_to_d_; + + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_x() { return NDimX; } + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_y() { return NDimY; } + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; } + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; } + + CK_TILE_HOST_DEVICE static constexpr auto get_lengths() + { +#if 0 + // FIXME: tensor_adaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed + ps_ys_to_xs_.GetBottomDimensionLengths(); +#else + return generate_tuple( + [&](auto i) { + constexpr index_t x_length = + container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies{}, 1); + + return number{}; + }, + number{}); +#endif + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_ps_ys_to_xs_adaptor() const + { + return ps_ys_to_xs_; + } + + CK_TILE_HOST_DEVICE constexpr const auto& get_ys_to_d_descriptor() const { return ys_to_d_; } + + CK_TILE_HOST_DEVICE static constexpr auto get_static_tile_distribution_encoding() + { + return DstrEncode{}; + } + +#if 1 + // Calculate Replication index [R0, R1, ...] based on Partion index + // FIXME: very nasty implementation + template + CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex& ps_idx) const + { + static_assert(PartitionIndex::size() == NDimP, "wrong!"); + + const auto ps_ys_idx = container_concat(ps_idx, array{0}); + + const auto dummy_adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx); + + array rs_idx; + + static_for<0, NDimP, 1>{}([&](auto idim_p) { + constexpr index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size(); + + static_for<0, ndim_low, 1>{}([&](auto i) { + constexpr index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i]; + constexpr index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i]; + + // 0-th rh_major is the replicate dimension + if constexpr(rh_major == 0) + { + constexpr index_t adaptor_hidden_id = + DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor]; + + // fill in + rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id]; + } + }); + }); + + return rs_idx; + } +#endif + + CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans() + { + constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_; + constexpr auto ndims_spans_minor = DstrEncode::detail::ndims_distributed_spans_minor_; + + return generate_tuple( + [&](auto i) { + constexpr auto span_impl = distributed_spans_impl[i]; + constexpr index_t ndim_span_minor = ndims_spans_minor[i]; + + constexpr auto span = TO_SEQUENCE(span_impl, ndim_span_minor); + + return detail::make_tile_distributed_span(span); + }, + number{}); + } + + // FIXME: it's hacky to get Y index from Distributed-Index + template + CK_TILE_HOST_DEVICE static constexpr auto + get_y_indices_from_distributed_indices(DistributedIndices) + { + constexpr auto ys_idx_arr = [] { + array ys_idx; + + static_for<0, NDimY, 1>{}([&](auto i) { + constexpr index_t span_major = DstrEncode::detail::ys_to_span_major_[i]; + constexpr index_t 
span_minor = DstrEncode::detail::ys_to_span_minor_[i]; + + constexpr auto dstr_index = DistributedIndices{}[number{}]; + + ys_idx(i) = dstr_index.impl_[span_minor]; + }); + + return ys_idx; + }(); + + constexpr index_t ndim_y = NDimY; + + return TO_SEQUENCE(ys_idx_arr, ndim_y); + } + + CK_TILE_HOST_DEVICE static constexpr bool is_static() + { + return PsYs2XsAdaptor::is_static() && Ys2DDescriptor::is_static(); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("tile_distribution{"); + // + printf("tile_distribution_encoding: "); + print(DstrEncode{}); + printf(", "); + // + printf("ps_ys_to_xs_: "); + print(ps_ys_to_xs_); + printf(", "); + // + printf("ys_to_d_: "); + print(ys_to_d_); + // + printf("}"); + } +}; + +namespace detail { + +template +CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t iend) +{ + array arr{0}; + + for(index_t i = 0; i < iend - ibegin; ++i) + { + arr(i) = ibegin + i; + } + + return arr; +} + +// this returns a constexpr encoding of tile_distribution +template +CK_TILE_HOST_DEVICE constexpr auto + make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_) +{ + using RsLengths = typename StaticTileDistributionEncoding_::RsLengths; + using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss; + using Ps2RHssMajor = typename StaticTileDistributionEncoding_::Ps2RHssMajor; + using Ps2RHssMinor = typename StaticTileDistributionEncoding_::Ps2RHssMinor; + using Ys2RHsMajor = typename StaticTileDistributionEncoding_::Ys2RHsMajor; + using Ys2RHsMinor = typename StaticTileDistributionEncoding_::Ys2RHsMinor; + + // FIXME: increase max value if fail + constexpr index_t kMaxNumTransforms = 20; + constexpr index_t kMaxMetaDataSize = 128; + constexpr index_t kMaxNumDim = 10; + + using Name = coord_transform_enum; + using MetaData = meta_data_buffer; + using NumDim = index_t; + using Dims = array; + using Lengths = array; + + // Tile Adaptor + // bottom dims [x0, x1, x2, ...] + // top dims [p0, p1, ..., y0, y1, ...] 
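    // Illustrative example (hypothetical encoding, for orientation only):
    //   RsLengths  = sequence<2>, HsLengthss = tuple<sequence<4, 8>, sequence<2, 16>>
    //   -> bottom dims X0 = 4*8 = 32, X1 = 2*16 = 32.
    // The replicate transform below adds one hidden dim of length 2 for R, each X
    // is unmerged into its H-minor hidden dims (4,8) and (2,16), and the P/Y top
    // dims are then formed from those hidden R/H dims: Ps via merge transforms
    // driven by Ps2RHssMajor/Minor, Ys by directly picking hidden dim ids via
    // Ys2RHsMajor/Minor.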
+ constexpr index_t ndim_x = HsLengthss::size(); + + // Dim Ids: [idim_x_major, idim_x_minor] to [idim_hidden] + array, ndim_x + 1> rh_major_minor_to_hidden_ids; + array, ndim_x + 1> rh_major_minor_to_hidden_lengths; + + auto trans = array, kMaxNumTransforms>{}; + + index_t num_tran = 0; + index_t hidden_dim_cnt = ndim_x; + + // this is replicate transform + { + constexpr index_t ndim_r_minor = RsLengths::size(); + + constexpr auto r_minor_lengths = RsLengths{}; + + trans(num_tran++) = { + coord_transform_enum::replicate, + MetaData{to_array(r_minor_lengths)}, + NumDim{0}, + Dims{}, + NumDim{ndim_r_minor}, + make_sequential_index(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)}; + + for(index_t i = 0; i < ndim_r_minor; ++i) + { + rh_major_minor_to_hidden_ids(0)(i) = hidden_dim_cnt; + rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i]; + + hidden_dim_cnt++; + } + }; + + // these are Unmerge transforms for X dimesions + static_for<0, ndim_x, 1>{}([&trans, + &num_tran, + &hidden_dim_cnt, + &rh_major_minor_to_hidden_ids, + &rh_major_minor_to_hidden_lengths](auto idim_x) { + // typename HsLengthss::base{}.foo(); + constexpr auto h_minor_lengths = + HsLengthss{}.get(idim_x); // std::tuple_element_t{}; + // constexpr auto h_minor_lengths = impl::getv(HsLengthss{}); + + constexpr index_t ndim_h_minor = h_minor_lengths.size(); + + trans(num_tran++) = { + coord_transform_enum::unmerge, + MetaData{to_array(h_minor_lengths)}, + NumDim{1}, + Dims{idim_x}, + NumDim{ndim_h_minor}, + make_sequential_index(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)}; + + for(index_t i = 0; i < ndim_h_minor; ++i) + { + rh_major_minor_to_hidden_ids(idim_x + 1)(i) = hidden_dim_cnt; + rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i]; + + hidden_dim_cnt++; + } + }); + + // transform: P dimensions + constexpr index_t ndim_p = Ps2RHssMajor::size(); + + Dims hidden_dim_id_ps; + + static_for<0, ndim_p, 1>{}([&](auto iDimP) { + // + index_t hidden_dim_id_p = hidden_dim_cnt++; + + hidden_dim_id_ps(iDimP) = hidden_dim_id_p; + + constexpr auto p2RHsMajor = Ps2RHssMajor{}[iDimP]; + constexpr auto p2RHsMinor = Ps2RHssMinor{}[iDimP]; + + static_assert(p2RHsMajor.size() == p2RHsMinor.size(), "wrong!"); + + constexpr index_t ndim_low = p2RHsMajor.size(); + + Dims low_dims; + Lengths low_lengths; + + for(index_t i = 0; i < ndim_low; ++i) + { + index_t rh_major = p2RHsMajor[i]; + index_t rh_minor = p2RHsMinor[i]; + low_dims(i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor]; + low_lengths(i) = rh_major_minor_to_hidden_lengths[rh_major][rh_minor]; + } + + trans(num_tran++) = {coord_transform_enum::merge, + MetaData{to_array(low_lengths)}, + NumDim{ndim_low}, + low_dims, + NumDim{1}, + Dims{hidden_dim_id_p}}; + }); + + constexpr index_t ndim_bottom = ndim_x; + + constexpr auto bottom_dim_ids = make_sequential_index(0, ndim_bottom); + + constexpr auto ys_to_rhs_major = Ys2RHsMajor{}; + constexpr auto ys_to_rhs_minor = Ys2RHsMinor{}; + + constexpr index_t ndim_y = Ys2RHsMajor::size(); + constexpr index_t ndim_top = ndim_p + ndim_y; + + auto top_dim_ids = hidden_dim_id_ps; + + { + for(index_t i = 0; i < ndim_y; ++i) + { + index_t rh_major = ys_to_rhs_major[i]; + index_t rh_minor = ys_to_rhs_minor[i]; + top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor]; + } + } + + // + const auto ps_ys_to_xs_adaptor_encoding = + make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top); + + // descriptor: [y0, y1, ...] 
to [d] + Lengths y_lengths; + index_t d_length = 1; + + for(index_t i = 0; i < ndim_y; ++i) + { + index_t rh_major = ys_to_rhs_major[i]; + index_t rh_minor = ys_to_rhs_minor[i]; + index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor]; + y_lengths(i) = y_length; + d_length *= y_length; + } + + auto tran = make_tuple(coord_transform_enum::unmerge, + MetaData{to_array(y_lengths)}, + NumDim{1}, + Dims{0}, + NumDim{ndim_y}, + make_sequential_index(1, ndim_y + 1)); + + const auto ys_to_d_adaptor_encoding = make_tuple( + make_tuple(tran), 1, Dims{0}, 1, make_sequential_index(1, ndim_y + 1), ndim_y); + + return make_tuple(ps_ys_to_xs_adaptor_encoding, + ys_to_d_adaptor_encoding, + d_length, + rh_major_minor_to_hidden_ids); +} + +// FIXME: this is nasty. Move it inside TileDistributionEncoding::detail +template // tuple, ...> +struct tile_distribution_detail +{ + static constexpr auto rh_major_minor_to_adaptor_hidden_idss_ = + to_array_of_array(RhMajorMinor2AdaptorHiddenIdss{}); +}; + +} // namespace detail + +// this returns a constexpr tile_distribution +template +CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_) +{ + using DstrEncode = remove_cvref_t; + + constexpr auto adaptor_impl = + detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{}); + + constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>(); + constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>(); + constexpr index_t d_length = adaptor_impl.template at<2>(); + constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>(); + + constexpr auto ps_ys_to_xs_adaptor = + CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl); + + constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl); + + constexpr auto ys_to_d_descriptor = + make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length); + + // + constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_; + constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_; + + constexpr auto rh_major_minor_to_hidden_ids = + TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor); + + return tile_distribution< + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + detail::tile_distribution_detail>>{ + ps_ys_to_xs_adaptor, ys_to_d_descriptor}; +} + +// this returns a static tile_distribution +template +CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_) +{ + using DstrEncode = remove_cvref_t; + + constexpr auto adaptor_impl = + detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{}); + + constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>(); + constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>(); + constexpr index_t d_length = adaptor_impl.template at<2>(); + constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>(); + + constexpr auto ps_ys_to_xs_adaptor = + CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl); + + constexpr auto ys_to_d_adaptor = + CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl); + + constexpr auto ys_to_d_descriptor = + make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, number{}); + + // + constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_; + constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_; + + constexpr auto 
rh_major_minor_to_hidden_ids = + TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor); + + return tile_distribution< + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + detail::tile_distribution_detail>>{ + ps_ys_to_xs_adaptor, ys_to_d_descriptor}; +} + +//*********************************************************************************** + +namespace detail { + +template +CK_TILE_HOST_DEVICE auto get_partition_index(Distribution) +{ + // only support warp-tile and block-tile + static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!"); + + if constexpr(Distribution::NDimP == 1) + { + return array{get_lane_id()}; + } + else if constexpr(Distribution::NDimP == 2) + { + return array{get_warp_id(), get_lane_id()}; + } +} + +template +struct reverse_slice_sequence_impl; + +template +struct reverse_slice_sequence_impl, + sequence, + sequence, + SliceSize> +{ + using old_scan = + reverse_slice_sequence_impl, sequence, sequence, SliceSize>; + + static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value; + static constexpr auto slice_length = + std::conditional_t, number>::value; + + using dim_lengths = + typename sequence_merge, typename old_scan::dim_lengths>::type; + using dim_slices = + typename sequence_merge, typename old_scan::dim_slices>::type; + using remaining_slice_sizes = typename sequence_merge< + std::conditional_t, sequence>, + typename old_scan::remaining_slice_sizes>::type; + + // the first idx that sliced length not equal to original length + static constexpr index_t _flag = + slice_length != x && remaining_slice_sizes{}.front().value == 1; + static constexpr index_t _split_flag = std::conditional_t, number<0>>::value; + static constexpr index_t _split_idx = + std::conditional_t<_split_flag, number, number<0>>::value; + + static constexpr index_t split_flag = _split_flag || old_scan::split_flag; + static constexpr index_t split_idx = std:: + conditional_t, number<_split_idx>>::value; +}; + +template +struct reverse_slice_sequence_impl, sequence, sequence, SliceSize> +{ + static constexpr auto slice_size = SliceSize; + static constexpr auto slice_length = + std::conditional_t, number>::value; + + using dim_lengths = sequence; + using dim_slices = sequence; + using remaining_slice_sizes = + std::conditional_t, sequence>; + + // the first idx that sliced length not equal to original length + static constexpr index_t _flag = + slice_length != x && remaining_slice_sizes{}.front().value == 1; + static constexpr index_t split_flag = std::conditional_t, number<0>>::value; + static constexpr index_t split_idx = + std::conditional_t, number<0>>::value; +}; + +// clang-format off +// input a sequence(with optional mask), and the SliceSize : size per slice +// output the sequence each slice, and number of slices +// +// e.g. 
<2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0 +// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2 +// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2 +// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1 +// +// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0 +// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0 +// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0 +// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1 +// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2 +// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2 +// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2 +// +// <4, 2, 1, 4, 2> / 4 -> +// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0 +// +// return tuple, slice_index is at which index will start +// have split slices (right -> left) +// or the first index that sliced length is different from the original length +// clang-format on +template ::type> +constexpr auto reverse_slice_sequence(Seq, + number, + Mask = typename uniform_sequence_gen::type{}) +{ + static_assert(Seq::size() == Mask::size()); + using sliced_type = + reverse_slice_sequence_impl::type, + SliceSize>; + static_assert(sliced_type::remaining_slice_sizes::front().value == 1, + "can not evenly divide this sequence, please check"); + return make_tuple(typename sliced_type::dim_lengths{}, + typename sliced_type::dim_slices{}, + number{}); +} + +// +// slice tensor from x_dim, result in split in y_dim, not p_dim. +// We don't support slice cross p_dim (aka, slice different threads) +// also, sliced along y_dim need be the first dim of current dim. 
+// Multiply Y dim before sliced dim does not make sense +// +// e.g +// X0 X1 +// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length) +// Y P P Y P Y P Y +// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK +// |--> slice along this Y dim, is the first dim of X1, totally 4 slices +// +// X0 X1 +// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length) +// Y P P Y P Y P Y +// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK +// |--> slice along this Y dim, the P dim is 1 in the left, so is OK +// totally 16 slices +// +// X0 X1 +// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length) +// Y P P Y P Y P Y +// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail +// |--> slice along this P dim, will split threads, not supported +// +// X0 X1 +// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length) +// Y P P Y P Y P Y +// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK +// |--> slice along this Y dim, but this Y sim need to split into 2 +// subdime +// the P dim in the left is 1, means actually not crossing P +// +template +CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( + Distribution, sequence x_slice_begins, sequence x_slice_ends) +{ + // NOTE: this function need to be called under constexpr context, + // due to https://wg21.link/p2280r0 we have to use non-reference type for distribution + using Encoding = decltype(Distribution::get_static_tile_distribution_encoding()); + + static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds)); + + constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins; + + constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum(); + constexpr auto src_y_info = Encoding::detail::get_sorted_y_info(); + constexpr auto src_y_dims = src_y_info[number<0>{}]; + constexpr auto src_y_maps = src_y_info[number<1>{}]; + constexpr auto src_y_prefix_sum = src_y_info[number<2>{}]; + + constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr + { + auto y_slice_sorted_origins = make_zero_multi_index(); + auto y_slice_lengths = Encoding::detail::ys_lengths_; + + // This lambda will modify some value outside, so c++ will not treat return value as + // constexpr + // TODO: ugly + auto new_h_lengths = transform_tuples( + [&](auto h_len, auto id) { + constexpr auto sliced_h = + reverse_slice_sequence(h_len, number{}); + + constexpr auto sliced_h_lens = sliced_h[number<0>{}]; + constexpr auto sliced_h_index = sliced_h[number<2>{}]; + + // update y_slice_lengths + constexpr auto uniformed_h_index = sliced_h_index + number{}; + constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index); + + static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(), + "not sliced at y dim, please check"); + + static_for<0, sliced_h_index + 1, 1>{}([&](auto i) { + y_slice_lengths(src_y_maps[found_y_index - i]) = + sliced_h_lens[sliced_h_index - i]; + }); + // TODO: add validations not across p dim + + // NOTE: this y_origin is for all dims, not only current dim + // will later use pick to select target dim + constexpr auto y_origin = [&]() { + constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len); + auto h_origin_ = make_zero_multi_index(); + h_trans.calculate_lower_index(h_origin_, sequence{}); + + auto y_origin_ = make_zero_multi_index(); + static_for<0, sliced_h_index + 1, 1>{}([&](auto i) { + y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i]; + }); + return y_origin_; + }(); + + constexpr auto y_picks = typename 
arithmetic_sequence_gen::type{}; + + set_container_subset( + y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks)); + return sliced_h_lens; + }, + typename Encoding::HsLengthss{}, + typename arithmetic_sequence_gen<0, Encoding::HsLengthss::size(), 1>::type{}); + + auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps); + + return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths); + } + (); + + constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}]; + constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}]; + constexpr auto sliced_y_origins_size = sliced_y_origins_array.size(); + constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[number<2>{}]; + constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.size(); + + constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size); + constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size); + + return make_tuple( + make_static_tile_distribution( + tile_distribution_encoding{}), + sliced_y_origins, + sliced_y_lengths); +} + +} // namespace detail +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp new file mode 100644 index 000000000..7b1e95202 --- /dev/null +++ b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp @@ -0,0 +1,760 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/algorithm/coordinate_transform.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/multi_index.hpp" +#include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template + typename HsLengthss_, // tuple, ...> + typename Ps2RHssMajor_, // tuple, ...> + typename Ps2RHssMinor_, // tuple, ...> + typename Ys2RHsMajor_, // sequence<...> + typename Ys2RHsMinor_> // sequence<...> +struct tile_distribution_encoding +{ + using RsLengths = remove_cvref_t; + using HsLengthss = remove_cvref_t; + using Ps2RHssMajor = remove_cvref_t; + using Ps2RHssMinor = remove_cvref_t; + using Ys2RHsMajor = remove_cvref_t; + using Ys2RHsMinor = remove_cvref_t; + + static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!"); + static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!"); + + static constexpr index_t NDimX = HsLengthss::size(); + static constexpr index_t NDimP = Ps2RHssMajor::size(); + static constexpr index_t NDimY = Ys2RHsMajor::size(); + static constexpr index_t NDimR = RsLengths::size(); + + // FIXME: move into detail + static constexpr auto rs_lengths_ = RsLengths{}; + static constexpr auto hs_lengthss_ = HsLengthss{}; + static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{}; + static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{}; + static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{}; + static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{}; + + // redundant but useful info + // TODO: really bad code, should be over-hauled + struct detail + { + // ndim_rh_major_, ndim_span_mainor_ + static constexpr index_t ndim_rh_major_ 
= NDimX + 1; + static constexpr index_t ndim_span_major_ = NDimX; + + // ndims_rhs_minor_[ndim_rh_major_] + static constexpr auto ndims_rhs_minor_ = generate_array( + [](auto i) { + if constexpr(i.value == 0) + { + return rs_lengths_.size(); + } + else + { + return hs_lengthss_[i - number<1>{}].size(); + } + }, + number{}); + + // max_ndim_rh_minor_ + static constexpr index_t max_ndim_rh_minor_ = + container_reduce(ndims_rhs_minor_, maximize{}, 0); + + // rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_] + static constexpr auto rhs_lengthss_ = + to_array_of_array(container_concat(make_tuple(rs_lengths_), hs_lengthss_)); + + // ys_lengths_ + static constexpr auto ys_lengths_ = [] { + array ys_lengths_tmp{-1}; + + for(index_t i = 0; i < NDimY; i++) + { + index_t rh_major = ys_to_rhs_major_[i]; + index_t rh_minor = ys_to_rhs_minor_[i]; + + ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor]; + } + + return ys_lengths_tmp; + }(); + + // rhs_major_minor_to_ys_[ndim_rh_majpr_][max_ndim_rh_minor_] + static constexpr auto rhs_major_minor_to_ys_ = [] { + array, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}}; + + static_for<0, NDimY, 1>{}([&](auto i) { + constexpr index_t rh_major = ys_to_rhs_major_[i]; + constexpr index_t rh_minor = ys_to_rhs_minor_[i]; + + rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i; + }); + + return rhs_major_minor_to_ys_tmp; + }(); + + // ndims_span_minor_[NDimY] + static constexpr auto ndims_span_minor_ = [] { + array ndims_span_minor{0}; + + for(index_t i = 0; i < NDimY; i++) + { + const index_t span_major = ys_to_rhs_major_[i] - 1; + + ndims_span_minor(span_major)++; + } + + return ndims_span_minor; + }(); + + // max_ndim_span_minor_ + static constexpr index_t max_ndim_span_minor_ = + container_reduce(ndims_span_minor_, maximize{}, 0); + + // rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_] + static constexpr auto rhs_major_minor_to_span_minor_ = [] { + array, ndim_rh_major_> rhs_major_minor_to_span_minor{ + {-1}}; + + static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) { + constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major]; + + index_t cnt_ndim_span_minor = 0; + + static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) { + constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor]; + + if(idim_y >= 0) + { + rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor; + + cnt_ndim_span_minor++; + } + }); + }); + + return rhs_major_minor_to_span_minor; + }(); + + // ys_to_span_major_[NDimY] + static constexpr auto ys_to_span_major_ = + generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, number{}); + + // ys_to_span_minor_[NDimY] + static constexpr auto ys_to_span_minor_ = generate_array( + [](auto i) { + return rhs_major_minor_to_span_minor_[ys_to_rhs_major_[i]][ys_to_rhs_minor_[i]]; + }, + number{}); + + // distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_] + static constexpr auto distributed_spans_lengthss_ = [] { + array, ndim_span_major_> + distributed_spans_lengthss{{-1}}; + + static_for<0, NDimY, 1>{}([&](auto i) { + const index_t rh_major = ys_to_rhs_major_[i]; + const index_t rh_minor = ys_to_rhs_minor_[i]; + + const index_t h_length = hs_lengthss_[number{}][rh_minor]; + + const index_t span_major = rh_major - 1; + const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor]; + + distributed_spans_lengthss(span_major)(span_minor) = h_length; + }); + + return distributed_spans_lengthss; + }(); + + // ndims_distributed_spans_minor_[ndim_span_major_] + static constexpr auto 
ndims_distributed_spans_minor_ = [] { + array ndims_distributed_spans_minor{0}; + + static_for<0, NDimY, 1>{}([&](auto i) { + const index_t span_major = ys_to_rhs_major_[i] - 1; + + ndims_distributed_spans_minor(span_major)++; + }); + + return ndims_distributed_spans_minor; + }(); + + // does_p_own_r_[NDimP][NDimR] + static constexpr auto does_p_own_r_ = [] { + if constexpr(NDimR > 0) + { + array, NDimP> does_p_own_r{{false}}; + + static_for<0, NDimP, 1>{}([&](auto idim_p) { + constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size(); + + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low]; + constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low]; + + if constexpr(rh_major == 0) + { + does_p_own_r(idim_p)(rh_minor) = true; + } + }); + }); + + return does_p_own_r; + } + else + { + return array, NDimP>{}; + } + }(); + + // ps_over_rs_derivative_[NDimP][NDimR] + static constexpr auto ps_over_rs_derivative_ = [] { + if constexpr(NDimR > 0) + { + array, NDimP> ps_over_rs_derivative{{0}}; + + static_for<0, NDimP, 1>{}([&](auto idim_p) { + constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size(); + + index_t p_over_rh_derivative = 1; + + static_for{}([&](auto idim_low) { + constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low]; + constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low]; + + constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor]; + + if constexpr(rh_major == 0) + { + ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative; + } + + p_over_rh_derivative *= rh_length; + }); + }); + + return ps_over_rs_derivative; + } + else + { + return array, NDimP>{}; + } + }(); + + // e.g. tuple, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8> + CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum() + { + // + // e.g. tuple, seq<4, 1, 4, 2, 4>> --> seq<3, 5> + constexpr auto uniformed_h_dim_lengths = generate_sequence_v2( + [&](auto i) { + constexpr index_t size = HsLengthss{}[i].size(); + return number{}; + }, + number{}); + + // <0, len_d0, len_d0+len_d1, ...> + // e.g. 
seq<3, 5> --> seq<0, 3, 8> + constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths); + + return h_dim_prefix_sum; + } + + CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h() + { + constexpr auto all_ys_2_rhss = transform_sequences( + [](auto major, auto minor) constexpr { + // <0, 0, len_d0, len_d0+len_d1, ...> + constexpr auto x_dim_prefix_sum = merge_sequences( + sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum()); + return x_dim_prefix_sum.at(major) + minor; + }, + Ys2RHsMajor{}, + Ys2RHsMinor{}); + + return all_ys_2_rhss; + } + + // return tuple + template + CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq) + { + using sorted_idx = sequence_unique_sort, equal>; + + constexpr auto sorted_dims = typename sorted_idx::type{}; + constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{}; + + constexpr auto sorted_histogram = + histogram_sorted_sequence(sorted_dims, PrefixSumSeq{}); + constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram); + + return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum); + } + + CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info() + { + return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum()); + } + + CK_TILE_HOST_DEVICE void print() const + { + printf("tile_distribution_encoding::detail{"); + // + printf("ndim_rh_major_: "); + print(ndim_rh_major_); + printf(", "); + // + printf("ndim_span_major_: "); + print(ndim_span_major_); + printf(", "); + // + printf("ndims_rhs_minor_: "); + print(ndims_rhs_minor_); + printf(", "); + // + printf("ndim_rh_major_: "); + print(ndim_rh_major_); + printf(", "); + // + printf("max_ndim_rh_minor_: "); + print(max_ndim_rh_minor_); + printf(", "); + // + printf("rhs_lengthss_: "); + print(rhs_lengthss_); + printf(", "); + // + printf("ys_lengths_: "); + print(ys_lengths_); + printf(", "); + // + printf("rhs_major_minor_to_ys_: "); + print(rhs_major_minor_to_ys_); + printf(", "); + // + printf("ndims_span_minor_: "); + print(ndims_span_minor_); + printf(", "); + // + printf("max_ndim_span_minor_: "); + print(max_ndim_span_minor_); + printf(", "); + // + printf("ys_to_span_major_: "); + print(ys_to_span_major_); + printf(", "); + // + printf("ys_to_span_minor_: "); + print(ys_to_span_minor_); + printf(", "); + // + printf("distributed_spans_lengthss_: "); + print(distributed_spans_lengthss_); + printf(", "); + // + printf("ndims_distributed_spans_minor_: "); + print(ndims_distributed_spans_minor_); + printf(", "); + // + printf("ps_over_rs_derivative_: "); + print(ps_over_rs_derivative_); + // + printf("}"); + } + }; + + CK_TILE_HOST_DEVICE void print() const + { + printf("tile_distribution_encoding{"); + // + printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY); + // + printf("rs_lengths_: "); + print(rs_lengths_); + printf(", "); + // + printf("hs_lengthss_: "); + print(hs_lengthss_); + printf(", "); + // + printf("ps_to_rhss_major_: "); + print(ps_to_rhss_major_); + printf(", "); + // + printf("ps_to_rhss_minor_: "); + print(ps_to_rhss_minor_); + printf(", "); + // + printf("ys_to_rhs_major_: "); + print(ys_to_rhs_major_); + printf(", "); + // + printf("ys_to_rhs_minor_: "); + print(ys_to_rhs_minor_); + printf(", "); + // + printf("detail: "); + print(detail{}); + // + printf("}"); + } +}; + +namespace detail { + +template +CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr) +{ + static_assert(OuterDstr::NDimX == 
InnerDstr::NDimX, "wrong!"); + + constexpr index_t NDimHMajor = OuterDstr::NDimX; + + using RsLengths = + sequence_merge_t; + + constexpr auto hs_lengthss = generate_tuple( + [&](auto i) { + return merge_sequences(typename OuterDstr::HsLengthss{}[i], + typename InnerDstr::HsLengthss{}[i]); + }, + number{}); + + // + constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() { + array rhs_major_2_ndim_outer_rhs_minor_; + + // R dimension + rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size(); + + // Hs dimensions + static_for<0, NDimHMajor, 1>{}([&](auto i) { + rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].size(); + }); + + return rhs_major_2_ndim_outer_rhs_minor_; + }(); + + // Ps2RHssMinor + constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple( + [&](auto p) { + constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p]; + constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p]; + + constexpr index_t ndim_tmp = inner_p_2_rhss_minor.size(); + + constexpr auto updated_inner_p_2_rhss_minor = [&]() { + array updated_inner_p_2_rhss_minor_; + + for(index_t i = 0; i < ndim_tmp; i++) + { + index_t rh_major = inner_p_2_rhss_major[i]; + + index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major]; + + updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor; + } + + return updated_inner_p_2_rhss_minor_; + }(); + + return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp); + }, + number{}); + + // Ys2RHsMinor + constexpr auto updated_inner_ys_2_rhs_minor = [&]() { + constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{}; + constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{}; + + constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.size(); + + constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() { + array updated_inner_ys_2_rhs_minor__; + + for(index_t i = 0; i < ndim_tmp; i++) + { + index_t rh_major = inner_ys_2_rhs_major[i]; + + index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major]; + + updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor; + } + + return updated_inner_ys_2_rhs_minor__; + }(); + + return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp); + }(); + + // + constexpr auto ps_2_rhss_major = + container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{}); + + constexpr auto ps_2_rhss_minor = + container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor); + + // + constexpr auto ys_2_rhs_major = + merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{}); + + constexpr auto ys_2_rhs_minor = + merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor); + + return tile_distribution_encoding, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t>{}; +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_reduce_tile_distribution_encoding_impl(InDstr, sequence reduce_dim_xs_in) +{ + constexpr auto I1 = number<1>{}; + + // FIXME: increase if fail + constexpr index_t max_ndim_r_out = 20; + constexpr index_t max_ndim_y_out = 20; + + // + constexpr index_t ndim_p = InDstr::NDimP; + constexpr index_t ndim_x_in = InDstr::NDimX; + constexpr index_t ndim_y_in = InDstr::NDimY; + constexpr index_t ndim_rh_major_in = InDstr::NDimX + 1; + constexpr index_t ndim_x_out = ndim_x_in - sizeof...(InReduceDimXs); + constexpr index_t max_ndim_rh_minor_in = 
InDstr::detail::max_ndim_rh_minor_; + + // ndims_ps_low + constexpr auto ndims_ps_low = generate_array( + [&](auto i) { return InDstr::ps_to_rhss_major_[i].size(); }, number{}); + + // is_rh_major_in_for_reduce + array is_rh_major_in_for_reduce{false}; + + for(index_t i = 0; i < reduce_dim_xs_in.size(); i++) + { + index_t rh_major = reduce_dim_xs_in[i] + 1; + + is_rh_major_in_for_reduce(rh_major) = true; + } + + // is_y_in_for_reduce + array is_y_in_for_reduce{false}; + + for(index_t i = 0; i < ndim_y_in; i++) + { + index_t rh_major = InDstr::ys_to_rhs_major_[i]; + + if(is_rh_major_in_for_reduce[rh_major]) + { + is_y_in_for_reduce(i) = true; + } + } + + // is_rh_minor_in_for_y_reduce + array, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{{false}}; + + static_for<0, ndim_y_in, 1>{}([&](auto i) { + index_t rh_major = InDstr::ys_to_rhs_major_[i]; + index_t rh_minor = InDstr::ys_to_rhs_minor_[i]; + + if(is_y_in_for_reduce[i]) + { + is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true; + } + }); + + // in2out_rh_major + array in2out_rh_major{-1}; + index_t cnt_ndim_rh_major_out = 0; + + for(index_t i = 0; i < ndim_rh_major_in; i++) + { + if(is_rh_major_in_for_reduce[i]) + { + in2out_rh_major(i) = 0; + } + else + { + in2out_rh_major(i) = cnt_ndim_rh_major_out; + + cnt_ndim_rh_major_out++; + } + } + + // rs_lengths_out, in2out_rh_minor + array rs_lengths_out{-1}; + array, ndim_rh_major_in> in2out_rh_minor{{-1}}; + + // loop over input R dim + for(index_t i = 0; i < InDstr::rs_lengths_.size(); i++) + { + // rs_lengths_out + rs_lengths_out(i) = InDstr::rs_lengths_[i]; + + // in2out_rh_minor + in2out_rh_minor(0)(i) = i; + } + + // loop over input H Dim + index_t cnt_ndim_r_out = InDstr::rs_lengths_.size(); + + static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) { + constexpr auto h_major_in = rh_major_in - I1; + + constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size(); + + if(is_rh_major_in_for_reduce[rh_major_in]) + { + for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++) + { + if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in]) + { + // rs_lengths_out + rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in]; + + // in2out_rh_minor + in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out; + + cnt_ndim_r_out++; + } + } + } + else + { + for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++) + { + // in2out_rh_minor + in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in; + } + } + }); + + // ndim_r_out + const index_t ndim_r_out = cnt_ndim_r_out; + + // ndims_hs_minor_out, hs_lengthss_out + array ndims_hs_minor_out{-1}; + array, ndim_x_out> hs_lengthss_out{{-1}}; + + index_t cnt_ndim_x_out = 0; + + static_for<0, ndim_x_in, 1>{}([&](auto i) { + if(not is_rh_major_in_for_reduce[i + I1]) + { + // ndims_hs_minor_out + ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size(); + + // hs_lengthss_out + static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}( + [&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; }); + + cnt_ndim_x_out++; + } + }); + + // ps_to_rhss_major_out, ps_to_rhss_minor_out + array, ndim_p> ps_to_rhss_major_out{{-1}}; + array, ndim_p> ps_to_rhss_minor_out{{-1}}; + + static_for<0, ndim_p, 1>{}([&](auto idim_p) { + static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](auto idim_low) { + index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low]; + index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low]; + + 
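+ // remap each of this P dimension's (rh_major, rh_minor) pairs through the
+ // in->out tables (in2out_rh_major / in2out_rh_minor) built above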
ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in]; + ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in]; + }); + }); + + // ys_to_rhs_major_out, ys_to_rhs_minor_out + array ys_to_rhs_major_out{-1}; + array ys_to_rhs_minor_out{-1}; + + index_t cnt_ndim_y_out = 0; + + static_for<0, ndim_y_in, 1>{}([&](auto i) { + if(not is_y_in_for_reduce[i]) + { + index_t rh_major_in = InDstr::ys_to_rhs_major_[i]; + index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i]; + + ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in]; + ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in]; + + cnt_ndim_y_out++; + } + }); + + // ndim_y_out + const index_t ndim_y_out = cnt_ndim_y_out; + + // + return make_tuple(ndim_x_out, + ndim_p, + ndim_y_out, + ndim_r_out, + ndims_hs_minor_out, + ndims_ps_low, + rs_lengths_out, + hs_lengthss_out, + ps_to_rhss_major_out, + ps_to_rhss_minor_out, + ys_to_rhs_major_out, + ys_to_rhs_minor_out); +} + +template +CK_TILE_HOST_DEVICE constexpr auto +make_reduce_tile_distribution_encoding(InDstr, sequence reduce_dim_xs_in) +{ + constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in); + + constexpr index_t ndim_x = impl.template at<0>(); + constexpr index_t ndim_p = impl.template at<1>(); + constexpr index_t ndim_y = impl.template at<2>(); + constexpr index_t ndim_r = impl.template at<3>(); + constexpr auto ndims_hs_minor = impl.template at<4>(); + constexpr auto ndims_ps_low = impl.template at<5>(); + constexpr auto rs_lengths_impl = impl.template at<6>(); + constexpr auto hs_lengthss_impl = impl.template at<7>(); + constexpr auto ps_to_rhss_major_impl = impl.template at<8>(); + constexpr auto ps_to_rhss_minor_impl = impl.template at<9>(); + constexpr auto ys_to_rhs_major_impl = impl.template at<10>(); + constexpr auto ys_to_rhs_minor_impl = impl.template at<11>(); + + constexpr auto rs_lengths = TO_SEQUENCE(rs_lengths_impl, ndim_r); + constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor); + constexpr auto ps_to_rhss_major = + TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low); + constexpr auto ps_to_rhss_minor = + TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low); + constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y); + constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y); + + return tile_distribution_encoding, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t, + remove_cvref_t>{}; +} + +} // namespace detail +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp new file mode 100644 index 000000000..90ad94b12 --- /dev/null +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -0,0 +1,263 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
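+//
+// tile_elementwise: element-wise operations over distributed (register) tiles.
+// Rough usage sketch only -- the distributed tensor construction is assumed to come
+// from the surrounding ck_tile headers and is not spelled out here:
+//
+//   auto acc = make_static_distributed_tensor<float>(some_tile_distribution);
+//   clear_tile(acc);                                  // zero-fill via set_tile(acc, 0)
+//   tile_elementwise_inout([](auto& x) { x *= 2.f; }, acc);
+//   auto acc_half = cast_tile<half_t>(acc);           // may use the packed/subdword paths below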
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/null_tensor.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// TODO: support tensors with different distribution +template , null_tensor>>...>>> +CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func, + InOutDstrTensors&... inout_dstr_tensors) +{ + // TODO: make sure all distributed tensors have same lengths and distribution + // static_assert(xxx); + + constexpr index_t thread_buffer_size = + __type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size(); + + static_for<0, thread_buffer_size, 1>{}( + [&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); }); +} + +template >...>>> +CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func, + const InTensor&... in_dstr_tensors) +{ + using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...)); + + // TODO: make sure all distributed tensors have same lengths and distribution + // static_assert(xxx); + constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution(); + + constexpr index_t thread_buffer_size = + __type_pack_element<0, InTensor...>::get_thread_buffer_size(); + + auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); + + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + out_dstr_tensor.get_thread_buffer()(i) = + in_element_func(in_dstr_tensors.get_thread_buffer()[i]...); + }); + + return out_dstr_tensor; +} + +template +CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value) +{ + tile_elementwise_inout( + [&value](auto& x) { + x = type_convert>(value); + }, + dstr_tensor); +} + +template +CK_TILE_DEVICE void set_tile(null_tensor&, const T&) +{ +} + +// TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with +// sub-dword tensor... 
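+// The number<v> overload below exists so the common "fill with compile-time zero" case can be
+// lowered to whole-dword stores: when v == 0 and the thread buffer size is a multiple of 4
+// bytes, the buffer is reinterpreted as an array of dword-sized index_t and written per dword;
+// otherwise it falls back to the generic tile_elementwise_inout path above.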
+template +CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number) +{ + constexpr index_t tensor_bytes = + DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType); + if constexpr(v == 0 && tensor_bytes % 4 == 0) + { + using dvec_t = array; + auto& tensor = reinterpret_cast(dstr_tensor.get_thread_buffer()); + for(auto i = 0; i < tensor.size(); i++) + tensor.get(i) = v; + } + else + { + tile_elementwise_inout( + [](auto& x) { x = type_convert(v); }, + dstr_tensor); + } +} + +template +CK_TILE_DEVICE void set_tile(null_tensor&, number) +{ +} + +template +CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor) +{ + set_tile(dstr_tensor, 0); +} + +namespace impl { +// TODO: this is ugly +template +CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) +{ +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + // This API is designed to use the _pk_ serious of function + constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); + + constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size(); + static_assert(thread_buffer_size % 4 == 0); + constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4; + + auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wuninitialized" + // __builtin_amdgcn_cvt_pk_fp8_f32() this builtin require the old value, and + // will generate a v_mov_b32 vxxx [old] before cvt, which result in unwanted ISA + // so we prepare an uninitialized variable purposely, and turn off the warning + int dummy_old; + static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) { + uint32_t x = __builtin_amdgcn_cvt_pk_fp8_f32( + in_dstr_tensors.get_thread_buffer()[number<4 * i + 0>{}], + in_dstr_tensors.get_thread_buffer()[number<4 * i + 1>{}], + dummy_old, + false); // false -> WORD0 + + uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32( + in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}], + in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}], + dummy_old, + false); // false -> WORD0 + + constexpr int32_t m0 = 0x05040100; + using vec_t = array; + + vec_t d = bit_cast(__builtin_amdgcn_perm(y, x, m0)); + out_dstr_tensor.get_thread_buffer().template set_as(number{}, d); + }); +#pragma clang diagnostic pop + + return out_dstr_tensor; +#else + // fallback + return tile_elementwise_in(type_convert, + in_dstr_tensors); +#endif +} + +#if CK_TILE_USE_SUBDWORD_TILE_CAST +// this function assume either src or dst (or both) date type is under 1 dword +// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy) +template +CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors) +{ + constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); + + auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); + + using i_type = remove_cvref_t; + using o_type = remove_cvref_t; + constexpr index_t i_elem_bytes = sizeof(i_type); + constexpr index_t o_elem_bytes = sizeof(o_type); + static_assert(i_elem_bytes < 4 || o_elem_bytes < 4); + + constexpr index_t bulk_size = + (i_elem_bytes >= o_elem_bytes) ? 
(4 / o_elem_bytes) : (4 / i_elem_bytes); + static_assert(bulk_size != 0); + + using o_bulk_type = + std::conditional_t= o_elem_bytes, float, array>; + + constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size(); + + constexpr index_t iters = thread_buffer_size / bulk_size; + constexpr index_t rems = thread_buffer_size % bulk_size; + + // cast the sequence per-bulk + static_for<0, iters, 1>{}([&](auto i) { + union bulk_wrapper + { + o_bulk_type bulk{}; + o_type data[bulk_size]; + } o_bulk; + + // TODO: should use below function, but somehow will result in spill (same as c-forloop) + static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib) { + o_bulk.data[ib.value] = static_cast( + in_dstr_tensors.get_thread_buffer() + .template get_as()[number{}]); + }); + + // TODO: fixme, should use above! + // static_assert(sizeof(i_type) / sizeof(o_type) == 2); + // o_bulk.data[0] = static_cast( + // in_dstr_tensors.get_thread_buffer().template get_as()[number<2 * i + 0>{}]); + // o_bulk.data[1] = static_cast( + // in_dstr_tensors.get_thread_buffer().template get_as()[number<2 * i + 1>{}]); + + out_dstr_tensor.get_thread_buffer().template set_as(i, o_bulk.bulk); + }); + + static_for<0, rems, 1>{}([&](auto r) { + // TODO: introducing local scratch pad? + auto idx = number{}; + out_dstr_tensor.get_thread_buffer().at(idx) = + static_cast(in_dstr_tensors.get_thread_buffer().at(idx)); + }); + + return out_dstr_tensor; +} +#endif +} // namespace impl + +template +CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) +{ + if constexpr((std::is_same_v || + std::is_same_v)&&std::is_same_v && + (SrcTensor::get_thread_buffer_size() % 4 == 0)) + { + return impl::cast_tile_pk_fp8x4(src_tensor); + } +#if CK_TILE_USE_SUBDWORD_TILE_CAST + else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4) + { + return impl::cast_tile_opt_subdword(src_tensor); + } +#endif + else + return tile_elementwise_in(type_convert, src_tensor); +} + +// no-op function for null_tensor arguments +template , null_tensor>...>>> +CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...) +{ +} + +// no-op function for null_tensor arguments +template , null_tensor>...>>> +CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...) +{ + return null_tensor{}; +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp new file mode 100644 index 000000000..09a4eb1fc --- /dev/null +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -0,0 +1,740 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
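+//
+// tile_window: a movable window over a bottom tensor view, paired with a tile_distribution
+// that maps per-thread [p..., y...] coordinates to bottom-tensor offsets. Rough usage sketch
+// only -- tensor_view, the window lengths, origin and distribution are placeholders assumed
+// to be provided by the caller:
+//
+//   auto win  = make_tile_window(tensor_view,
+//                                make_tuple(number<kM>{}, number<kN>{}),   // window lengths
+//                                {i_m, i_n},                               // window origin
+//                                tile_distribution);
+//   auto tile = win.load();               // distributed register tile
+//   move_tile_window(win, {0, kN});       // slide the window
+//   win.store(tile);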
+ +#pragma once + +#include "ck_tile/core/arch/utility.hpp" +#include "ck_tile/core/algorithm/space_filling_curve.hpp" +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/tensor/static_distributed_tensor.hpp" +#include "ck_tile/core/tensor/tensor_adaptor.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/core/utility/functional.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct tile_window_with_static_distribution +{ + using BottomTensorView = remove_reference_t; + using WindowLengths = remove_cvref_t; + using TileDstr = remove_cvref_t; + + using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor; + using BottomTensorDesc = typename BottomTensorView::TensorDesc; + + using DataType = remove_cvref_t; + + static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension(); + static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); + + static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p(); + static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y(); + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + + // TODO: check WindowLengths and StaticTileDistribution are consistent + + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + static_assert(TileDstr::is_static(), "wrong!"); + + static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(), + "wrong! inconsistent # of diemsnions"); + + using AdaptorTopIndex = array; + using BottomTensorIndex = array; + + using WindowAdaptorCoord = + decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})); + + using BottomTensorCoord = + decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})); + + struct load_store_traits + { + private: + static constexpr auto get_vector_dim_y_scalar_per_vector() + { + const auto [ys_vector_lengths, ys_vector_strides] = + tile_window_with_static_distribution:: + get_window_adaptor_ys_safe_vector_length_strides(); + + index_t VectorDimY_ = 0; + index_t ScalarPerVector_ = 1; + + for(index_t i = 0; i < NDimY; ++i) + { + if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_) + { + ScalarPerVector_ = ys_vector_lengths[i]; + VectorDimY_ = i; + } + } + + return make_tuple(VectorDimY_, ScalarPerVector_); + } + + public: + static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>(); + static constexpr index_t ScalarPerVector = + get_vector_dim_y_scalar_per_vector().template at<1>(); + + // using vector_type_t = vector_type_maker_t; + // using vector_t = typename vector_type_t::type; + using vector_t = thread_buffer; + + private: + static constexpr auto scalars_per_access_ = [] { + constexpr auto scalars_per_access_arr = generate_array( + [&](auto i) { return (i == VectorDimY) ? 
ScalarPerVector : 1; }, number{}); + + /// TODO: add non-automatic storage argument support to macro TO_SEQUENCE() + constexpr auto NDimY_ = NDimY; + + return TO_SEQUENCE(scalars_per_access_arr, NDimY_); + }(); + + static constexpr auto get_space_filling_curve() + { + constexpr auto tile_dstr = TileDstr{}; + + constexpr auto thread_tensor_lengths_ys = + to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths()); + + // FIXME: need logic to judge dim access order + using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type; + + return space_filling_curve{}; + } + + public: + using SFC_Ys = decltype(get_space_filling_curve()); + + static constexpr index_t NumAccess = SFC_Ys::get_num_of_access(); + + static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0"); + static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord"); + }; + + static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord; + + CK_TILE_DEVICE constexpr tile_window_with_static_distribution() = default; + + CK_TILE_DEVICE constexpr tile_window_with_static_distribution( + const BottomTensorView& bottom_tensor_view, + const WindowLengths& window_lengths, + const BottomTensorIndex& window_origin, + const TileDstr& tile_distribution) + : bottom_tensor_view_{bottom_tensor_view}, + window_lengths_{window_lengths}, + window_origin_{window_origin}, + tile_dstr_{tile_distribution}, + pre_computed_coords_{} + { +#if 0 // debug + // TODO: this use more register for FA, but less register for GEMM + // need investigation + // only support warp-tile and block-tile + static_assert(NDimP == 1 or NDimP == 2, "wrong!"); + + WindowAdaptorCoord window_adaptor_thread_coord_tmp; + + if constexpr(NDimP == 1) + { + window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0}); + } + else if constexpr(NDimP == 2) + { + window_adaptor_thread_coord_tmp = + make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(), + AdaptorTopIndex{get_warp_id(), get_lane_id(), 0}); + } +#else + // TODO: this use less register for FA, but more register for GEMM + // need investigation + const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), + container_concat(detail::get_partition_index(tile_distribution), + array{0})); +#endif + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = + window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); + + const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( + bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); + + // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up + // future load/store() calls (might allocate more registers) + using Traits = load_store_traits; + using SFC_Ys = typename Traits::SFC_Ys; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp; + auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; + + constexpr auto idx_diff_ys = + SFC_Ys::get_step_between(number<0>{}, number{}); + + constexpr auto idx_diff_ps_ys = container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + pre_computed_coords_(iCoord) = + make_tuple(window_adaptor_thread_coord, 
bottom_tensor_thread_coord); + }); + } + + CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } + + CK_TILE_DEVICE static constexpr bool has_static_tile_distribution() + { + return TileDstr::is_static(); + } + + CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } + + CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; } + + CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } + + CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } + + // move thread's window adaptor coordinate and bottom tensor coordinate + // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset] + CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate( + WindowAdaptorCoord& window_adaptor_thread_coord, + BottomTensorCoord& bottom_tensor_thread_coord, + const AdaptorTopIndex& idx_diff_adaptor_top) const + { + array idx_diff_adaptor_bottom; + + move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(), + window_adaptor_thread_coord, + idx_diff_adaptor_top, + idx_diff_adaptor_bottom); + + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + bottom_tensor_thread_coord, + idx_diff_adaptor_bottom); + } + + // return vector dimension among [y0, y1, ...] + CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides() + { + // bottom tensor top dimension vector lengths and strides + const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] = + BottomTensorDesc::get_top_dimension_safe_vector_length_strides(); + + // window vector lengths/strides + const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths; + const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides; + + // window adaptor [p0, p1, ..., y0, y1, ...] + array window_adaptor_vector_lengths{ + -1}; + array window_adaptor_vector_strides{ + -1}; + + constexpr auto window_adaptor_bottom_dims = + WindowAdaptor::get_bottom_dimension_hidden_ids(); + + set_container_subset(window_adaptor_vector_lengths, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_lengths); + set_container_subset(window_adaptor_vector_strides, + window_adaptor_bottom_dims, + window_adaptor_bottom_dim_vector_strides); + + const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] = + WindowAdaptor{}.get_top_dimension_safe_vector_length_strides( + window_adaptor_vector_lengths, window_adaptor_vector_strides); + + // [y0, y1, ...] + constexpr auto y_dims = typename arithmetic_sequence_gen::type{}; + + return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims), + get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims)); + } + + CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; } + + template + CK_TILE_DEVICE auto load(bool_constant = {}) const + { + using Traits = load_store_traits; + + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + auto dst_tensor = make_static_distributed_tensor(tile_dstr); + + // loop over thread tensor space [y0, y1, ...] 
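+ // Each of the NumCoord pre-computed (window adaptor, bottom tensor) coordinate bundles
+ // performs NumAccessPerCoord vectorized reads along the space-filling curve SFC_Ys; every
+ // vector_t covers ScalarPerVector contiguous elements along VectorDimY and is scattered
+ // element by element into dst_tensor's thread buffer.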
+ static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from bottom tensor + const vector_t vec_value = + get_bottom_tensor_view().template get_vectorized_elements( + bottom_tensor_thread_coord, bool_constant{}); +#if 1 + // write into distributed tensor + static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_array( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + dst_tensor.get_thread_buffer().template at() = + vec_value.template get_as()[j]; + }); +#else + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); + static_assert(d % Traits::ScalarPerVector == 0); + + dst_tensor.get_thread_buffer().template get_as()( + number{}) = bit_cast(vec_value); +#endif + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + + return dst_tensor; + } + + template + CK_TILE_DEVICE void load_raw(DstTile& dst_tensor, + bool_constant = {}) const + { + using Traits = load_store_traits; + + // using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + static constexpr index_t YElementSize = + TileDstr{}.get_ys_to_d_descriptor().get_element_space_size(); + static_assert(YElementSize % Traits::ScalarPerVector == 0); + using vectorized_tbuf = array; + // StaticBuffer; + + constexpr auto tile_dstr = TileDstr{}; + + auto& dst_vec_tbuf = reinterpret_cast(dst_tensor.get_thread_buffer()); + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] 
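+ // Unlike load(), the raw path below copies each whole vector_t directly into
+ // dst_tensor's thread buffer (reinterpreted above as an array of vector_t),
+ // so no per-scalar scatter is needed.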
+ constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); + static_assert(d % Traits::ScalarPerVector == 0); + + get_bottom_tensor_view().template get_vectorized_elements_raw( + dst_vec_tbuf.template at(), + bottom_tensor_thread_coord, + bool_constant{}); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + // TODO: currently async load only implemented in inline asm + template + CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile, + bool_constant = {}) const + { + using LdsTileWindow = remove_cvref_t; + // using LdsTensorView = typename LdsTileWindow::BottomTensorView; + using LdsDataType = typename LdsTileWindow::DataType; + // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc; + + // issues * warps * lanes + static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded + + const index_t size_per_buf = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType); + + const index_t size_per_wave = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<0>{}, number<1>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t size_per_issue = + lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset( + make_tuple(number<1>{}, number<0>{}, number<0>{})) * + sizeof(LdsDataType) - + size_per_buf; + + const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id(); + m0_set_with_memory(m0_init_value); // This should be wave independent + + using Traits = load_store_traits; + + // using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_; + + // loop over thread tensor space [y0, y1, ...] 
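+ // Async path: the LDS destination is addressed through the M0 register, primed above with
+ // this wave's base offset (size_per_buf + size_per_wave * warp id) and advanced by
+ // size_per_issue between issues of async_get_vectorized_elements().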
+ static_for<0, NumCoord, 1>{}([&](auto iCoord) { + // TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // read from bottom tensor + get_bottom_tensor_view().template async_get_vectorized_elements( + smem, bottom_tensor_thread_coord); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + m0_inc_with_memory(size_per_issue); + } + }); + }); + } + + template + CK_TILE_DEVICE void store(const static_distributed_tensor& dstr_tensor, + bool_constant = {}) const + { + using Traits = load_store_traits; + + // using vector_type_t = typename Traits::vector_type_t; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from distributed tensor + // vector_type_t vec; + vector_t vec_value; + + static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_array( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + + vec_value.template get_as()(j) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // const vector_t vec_value = vec.template get_as().template at<0>(); + + // write into bottom tensor + get_bottom_tensor_view().template set_vectorized_elements( + bottom_tensor_thread_coord, vec_value, bool_constant{}); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + CK_TILE_DEVICE void + store_raw(const static_distributed_tensor& dstr_tensor) const + { + using Traits = load_store_traits; + + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = TileDstr{}; + static constexpr bool oob_conditional_check = true; + + // loop over thread tensor space [y0, y1, ...] 
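+ // Same gather-into-vector_t pattern as store(), but the write goes through
+ // set_vectorized_elements_raw() with oob_conditional_check fixed to true here.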
+ static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from distributed tensor + vector_t vec_value; + static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { + constexpr auto idx_ys = generate_array( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys); + vec_value.template get_as()(j) = + dstr_tensor.get_thread_buffer().template at(); + }); + + // write into bottom tensor + get_bottom_tensor_view() + .template set_vectorized_elements_raw( + bottom_tensor_thread_coord, vec_value); + + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = + container_concat(array{0}, idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + + // move thread's botom tensor coordiante + // [x0', x1', ... ] ==> [offset] + // also move window-origin + CK_TILE_DEVICE void move(const BottomTensorIndex& step) + { + window_origin_ += step; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + pre_computed_coords_(iCoord)(I1), + step); + }); + } + + // this is the bottom tensor view + // [x0', x1', ...] ==> [offset] + BottomTensorView bottom_tensor_view_; + + // + WindowLengths window_lengths_; + + // origin ([x0', x1', ...]) of window on bottom tensor + BottomTensorIndex window_origin_; + + // Tile tensor distribution, which contains: + // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] + // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d] + TileDstr tile_dstr_; + + // this contains: + // per-thread coordinate for window adaptor + // per-thread coordinate for bottom tensor + array, NumCoord> pre_computed_coords_; +}; + +// TODO: use strategy +template +CK_TILE_DEVICE constexpr auto +make_tile_window(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin, + const StaticTileDistribution_& tile_distribution, + number = {}) +{ + return tile_window_with_static_distribution, + remove_cvref_t, + remove_cvref_t, + NumCoord>{ + tensor_view, window_lengths, origin, tile_distribution}; +} + +template +CK_TILE_DEVICE void move_tile_window( + tile_window_with_static_distribution& window, + const typename tile_window_with_static_distribution::BottomTensorIndex& step) +{ + window.move(step); +} + +template +struct tile_window_with_static_lengths +{ + using BottomTensorView = remove_reference_t; + using WindowLengths = remove_cvref_t; + using BottomTensorDesc = typename BottomTensorView::TensorDesc; + using DataType = typename BottomTensorView::DataType; + + static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); + + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! 
lengths should be static"); + + using BottomTensorIndex = array; + + CK_TILE_DEVICE constexpr tile_window_with_static_lengths() = default; + + CK_TILE_DEVICE constexpr tile_window_with_static_lengths( + const BottomTensorView& bottom_tensor_view, + const WindowLengths& window_lengths, + const BottomTensorIndex& window_origin) + : bottom_tensor_view_{bottom_tensor_view}, + window_lengths_{window_lengths}, + window_origin_{window_origin} + { + } + + CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } + + CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; } + + CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; } + + CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; } + + // move window-origin + CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; } + + // this is the bottom tensor view + // [x0', x1', ...] ==> [offset] + BottomTensorView bottom_tensor_view_; + + // + WindowLengths window_lengths_; + + // origin ([x0', x1', ...]) of window on bottom tensor + BottomTensorIndex window_origin_; +}; + +template +CK_TILE_DEVICE constexpr auto +make_tile_window(const TensorView_& tensor_view, + const WindowLengths_& window_lengths, + const multi_index& origin) +{ + static_assert(ck_tile::is_known_at_compile_time::value, + "wrong! lengths should be static"); + + return tile_window_with_static_lengths, + remove_cvref_t>{ + tensor_view, window_lengths, origin}; +} + +template +CK_TILE_DEVICE void move_tile_window( + tile_window_with_static_lengths& window, + const typename tile_window_with_static_lengths::BottomTensorIndex& + step) +{ + window.move(step); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/bit_cast.hpp b/include/ck_tile/core/utility/bit_cast.hpp new file mode 100644 index 000000000..2cb91b7d4 --- /dev/null +++ b/include/ck_tile/core/utility/bit_cast.hpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X& x) +{ + static_assert(__has_builtin(__builtin_bit_cast), ""); + static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type"); + + return __builtin_bit_cast(Y, x); +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/functional.hpp b/include/ck_tile/core/utility/functional.hpp new file mode 100644 index 000000000..2cdce9406 --- /dev/null +++ b/include/ck_tile/core/utility/functional.hpp @@ -0,0 +1,208 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/integer.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/container/sequence.hpp" +#include +#include + +namespace ck_tile { + +namespace detail { + +struct swallow +{ + template + CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...) 
+ { + } +}; + +template +struct static_for_impl; + +template +struct static_for_impl> +{ + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f) const + { + swallow{(f(number{}), 0)...}; + } +}; + +} // namespace detail + +// F signature: F(number) +template +struct static_for +{ + CK_TILE_HOST_DEVICE constexpr static_for() + { + static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0, + "Wrong! should satisfy (NEnd - NBegin) % Increment == 0"); + static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd), + "wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && " + "NBegin >= NEnd)"); + } + + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f) const + { + detail::static_for_impl::type>{}( + f); + } +}; + +struct identity +{ + template + CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept + { + return std::forward(arg); + } +}; + +namespace detail { + +// RemainLengths: sequence<...> +// Orders: sequence<...> +template +struct static_ford_impl +{ + CK_TILE_HOST_DEVICE constexpr static_ford_impl() + { + static_assert(RemainLengths::size() > 0, "wrong! should not get here"); + } + + // F signature: F(sequence<...>) + // CurrentOrderedId: sequence<...> + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const + { + static_for<0, RemainLengths::front(), 1>{}([=](auto I) { + static_ford_impl{}( + f, CurrentOrderedId::push_back(I)); + }); + } +}; + +template +struct static_ford_impl, Orders> +{ + // F signature: F(sequence<...>) + // OrderedId: sequence<...> + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const + { + // retrive unordered Id + f(OrderedId::reorder_old_to_new(Orders{})); + } +}; + +} // namespace detail + +// Lengths is sequence<...>, it is the length of each dimension for +// N-dimensional loop +// Orders is sequence<...>, it is the order of dimension in which static_ford +// will loop over each +// dimension +template ::type> +struct static_ford +{ + CK_TILE_HOST_DEVICE constexpr static_ford() + { + static_assert(Lengths::size() > 0, "wrong! Lengths is empty"); + static_assert(Lengths::size() == Orders::size(), "wrong! 
inconsistent size"); + } + + // F signature: F(sequence<...> multi_id) + // multi_id is the unordered multi-index + template + CK_TILE_HOST_DEVICE constexpr void operator()(F f) const + { + constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{}); + detail::static_ford_impl{}(f, sequence<>{}); + } +}; + +namespace detail { + +template +struct unpack_impl; + +template +struct unpack_impl> +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const + { +#if 0 + return std::forward(f)(std::forward(x).at(number{})...); +#else + return std::forward(f)(std::forward(x).template at()...); +#endif + } +}; + +template +struct unpack2_impl; + +// TODO: remove this, after properly implementing unpack that takes any number of containers +template +struct unpack2_impl, sequence> +{ + template + CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const + { +#if 0 + return std::forward(f)(std::forward(x).at(number{})..., + std::forward(y).at(number{})...); +#else + return std::forward(f)(std::forward(x).template at()..., + std::forward(y).template at()...); +#endif + } +}; + +} // namespace detail + +template +CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x) +{ + using X_ = remove_reference_t; + return detail::unpack_impl::type>{}( + std::forward(f), std::forward(x)); +} + +// TODO: properly implement unpack that takes any number of containers +template +CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y) +{ + using X_ = remove_reference_t; + using Y_ = remove_reference_t; + return detail::unpack2_impl::type, + typename arithmetic_sequence_gen<0, Y_::size(), 1>::type>{}( + std::forward(f), std::forward(x), std::forward(y)); +} + +// z = predicate ? x : y +template +constexpr auto conditional_expr(X&& x, Y&& y) +{ + if constexpr(predicate) + { + return std::forward(x); + } + else + { + return std::forward(y); + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/ignore.hpp b/include/ck_tile/core/utility/ignore.hpp new file mode 100644 index 000000000..eead91495 --- /dev/null +++ b/include/ck_tile/core/utility/ignore.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +// https://en.cppreference.com/w/cpp/utility/tuple/ignore + +namespace ck_tile { + +namespace detail { +struct ignore_t +{ + template + constexpr void operator=(T&&) const noexcept + { + } +}; +} // namespace detail + +inline constexpr detail::ignore_t ignore; + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/magic_div.hpp b/include/ck_tile/core/utility/magic_div.hpp new file mode 100644 index 000000000..09038ba29 --- /dev/null +++ b/include/ck_tile/core/utility/magic_div.hpp @@ -0,0 +1,240 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/tuple.hpp" +#include "ck_tile/core/numeric/integral_constant.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include + +namespace ck_tile { + +// magic number division +// Caution: +// 1. For uint32_t as dividend: magic number division implementation being used would produce +// correct result if the dividend is uint32_t and its value is within 31-bit value range. +// 2. 
For int32_t as dividendd: magic number division for int32_t dividened has not been +// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number +// division implementation for uint32_t is then used. Therefore, dividend value need to be +// non-negative. +// TODO: +// 1. Implement magic number divison for int32_t +// 2. Implement magic number divison for unit32_t with 32-bit value range +struct magic_division32_bit_range +{ + // uint32_t + CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor) + { + // WARNING: magic division is only valid for division inside this range. + // assert(divisor >= 1 && divisor <= INT32_MAX) + + uint32_t shift_u32 = 0; + + while((1U << shift_u32) < divisor) + { + shift_u32++; + }; + + uint64_t tmp_u64 = ((1UL << shift_u32) - divisor) << 32; + uint32_t multiplier_u32 = tmp_u64 / divisor + 1; + + return make_tuple(multiplier_u32, shift_u32); + } + + template > + CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant) + { + constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor}); + + constexpr uint32_t multiplier = tmp[number<0>{}]; + constexpr uint32_t shift = tmp[number<1>{}]; + + return make_tuple(constant{}, constant{}); + } + + // magic division for uint32_t + CK_TILE_DEVICE static constexpr uint32_t + do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift) + { + uint32_t tmp = __umulhi(dividend, multiplier); + return (tmp + dividend) >> shift; + } + + CK_TILE_HOST static constexpr uint32_t + do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift) + { + uint32_t tmp = (static_cast(dividend) * multiplier) >> 32; + return (tmp + dividend) >> shift; + } + + // magic division for int32_t + // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be + // non-negative for result to be correct + // TODO: figure out how to do magic number divison for int32_t as dividended + CK_TILE_DEVICE static constexpr int32_t + do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t dividend_u32 = bit_cast(dividend_i32); + uint32_t tmp = __umulhi(dividend_u32, multiplier); + return (tmp + dividend_u32) >> shift; + } + + CK_TILE_HOST static constexpr int32_t + do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t dividend_u32 = bit_cast(dividend_i32); + uint32_t tmp = (static_cast(dividend_u32) * multiplier) >> 32; + return (tmp + dividend_u32) >> shift; + } +}; + +// magic number division +// This version on works for divisor and dividended between [0, 1 << 16] +struct magic_division16_bit_range +{ + // uint32_t + CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor) + { + // WARNING: magic division is only valid for division inside this range. 
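+ // Worked example (divisor = 3): shift = 2 since (1 << 2) >= 3, and
+ // multiplier = ((1 << 16) * ((1 << 2) - 3)) / 3 + 1 = 21846. For dividend = 10:
+ // tmp = (10 * 21846) >> 16 = 3, and (tmp + 10) >> 2 = 3, which matches 10 / 3.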
+ // assert(divisor >= 1 && divisor <= (1U << 16)); + + uint32_t shift_u32 = 0; + + while((1U << shift_u32) < divisor) + { + shift_u32++; + }; + + uint32_t one = 1; + uint32_t multiplier_u32 = ((one << 16) * ((one << shift_u32) - divisor)) / divisor + 1; + + return make_tuple(multiplier_u32, shift_u32); + } + + // integral_constant + template + CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant) + { + constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor}); + + constexpr uint32_t multiplier = tmp[number<0>{}]; + constexpr uint32_t shift = tmp[number<1>{}]; + + return make_tuple(constant{}, constant{}); + } + + // magic division for uint32_t + CK_TILE_DEVICE static constexpr uint32_t + do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift) + { + uint32_t tmp = (dividend * multiplier) >> 16; + return (tmp + dividend) >> shift; + } + + CK_TILE_HOST static constexpr uint32_t + do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift) + { + uint32_t tmp = (dividend * multiplier) >> 16; + return (tmp + dividend) >> shift; + } + + // magic division for int32_t + // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be + // non-negative for result to be correct + // TODO: figure out how to do magic number divison for int32_t as dividended + CK_TILE_DEVICE static constexpr int32_t + do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t dividend_u32 = bit_cast(dividend_i32); + uint32_t tmp = (dividend_u32 * multiplier) >> 16; + return (tmp + dividend_u32) >> shift; + } + + CK_TILE_HOST static constexpr int32_t + do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t dividend_u32 = bit_cast(dividend_i32); + uint32_t tmp = (dividend_u32 * multiplier) >> 16; + return (tmp + dividend_u32) >> shift; + } +}; + +// use 32bit version +using magic_division = magic_division32_bit_range; + +struct mdiv +{ + // 1 dword -> 3 dword storage + uint32_t divisor; + uint32_t multiplier; + uint32_t shift; // TODO: 8 bit is enough + + // prefer construct on host + CK_TILE_HOST_DEVICE mdiv(uint32_t divisor_) : divisor(divisor_) + { + auto tmp = magic_division::calculate_magic_numbers(divisor_); + + multiplier = tmp[number<0>{}]; + shift = tmp[number<1>{}]; + } + + CK_TILE_HOST_DEVICE mdiv() : divisor(0), multiplier(0), shift(0) {} + + CK_TILE_HOST_DEVICE void update(uint32_t divisor_) + { + divisor = divisor_; + auto tmp = magic_division::calculate_magic_numbers(divisor_); + + multiplier = tmp[number<0>{}]; + shift = tmp[number<1>{}]; + } + + CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const + { + return magic_division::do_magic_division(dividend_, multiplier, shift); + } + + CK_TILE_HOST_DEVICE void + divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const + { + quotient_ = div(dividend_); + remainder_ = dividend_ - (quotient_ * divisor); + } + + CK_TILE_HOST_DEVICE uint32_t get() const { return divisor; } +}; + +struct mdiv2 +{ + // 1 dword -> 2 dword storage, divisor need compute from runtime + uint32_t multiplier; + uint32_t shift; // TODO: 8 bit is enough + + // prefer construct on host + CK_TILE_HOST_DEVICE mdiv2(uint32_t divisor_) + { + auto tmp = magic_division::calculate_magic_numbers(divisor_); + + multiplier = tmp[number<0>{}]; + shift = tmp[number<1>{}]; + } + + CK_TILE_HOST_DEVICE mdiv2() : multiplier(0), shift(0) {} + + CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const + { + return 
magic_division::do_magic_division(dividend_, multiplier, shift); + } + + CK_TILE_HOST_DEVICE void + divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const + { + quotient_ = div(dividend_); + remainder_ = dividend_ - (quotient_ * divisor_); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/random.hpp b/include/ck_tile/core/utility/random.hpp new file mode 100644 index 000000000..f7fbfad4d --- /dev/null +++ b/include/ck_tile/core/utility/random.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include +#include +#include + +namespace ck_tile { + +// return 0 if data is not fp16 or fp32 +template +struct prand_generator_t +{ + CK_TILE_HOST_DEVICE uint32_t operator()(int, T, uint32_t = seed_) { return 0; } +}; + +// version for fp32 +template +struct prand_generator_t +{ + CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_) + { + uint32_t x = *(reinterpret_cast(&val)); + uint32_t drop_bits = uint32_t(x) & 0xFFFFu; + drop_bits ^= x >> 16; + drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); + drop_bits *= 0x7000149; + // NOTE: If id is in 64 bit, we are only using lower 32 bit. + // So, it can have an effect of using same id for multiple elements when the id is + // very large! + uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed); + return rng; + } +}; + +// version for fp16 +template +struct prand_generator_t +{ + CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_) + { + uint16_t x = *(reinterpret_cast(&val)); + uint32_t drop_bits = uint32_t(x) & 0xFFFFu; + drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); + drop_bits *= 0x7000149; + // NOTE: If id is in 64 bit, we are only using lower 32 bit. + // So, it can have an effect of using same id for multiple elements when the id is + // very large! + uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed); + return rng; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/to_sequence.hpp b/include/ck_tile/core/utility/to_sequence.hpp new file mode 100644 index 000000000..2276ab68b --- /dev/null +++ b/include/ck_tile/core/utility/to_sequence.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "ck_tile/core/container/sequence.hpp" +// TODO: use c++20 nontype template with struct to implement this + +#if 1 +// clang happen to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode +#define TO_SEQUENCE(a, n) \ + _Pragma("clang diagnostic push") _Pragma( \ + "clang diagnostic ignored \"-Wc++20-extensions\"")[a]( \ + ck_tile::sequence) \ + { \ + return ck_tile::sequence{})...>{}; \ + } \ + (ck_tile::make_index_sequence{}); \ + _Pragma("clang diagnostic pop") + +#else +// Macro function +// convert constexpr array to sequence, both a/n need to be constexpr (can't be a rvalue like 2) +#define TO_SEQUENCE(a, n) \ + [a, n] { \ + static_assert(a.size() >= n, "wrong! 
out of bound"); \ + static_assert(n <= 10, "not implemented"); \ + if constexpr(n == 0) \ + { \ + return ck_tile::sequence<>{}; \ + } \ + else if constexpr(n == 1) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 2) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 3) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 4) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 5) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 6) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 7) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 8) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 9) \ + { \ + return ck_tile::sequence{}; \ + } \ + else if constexpr(n == 10) \ + { \ + return ck_tile:: \ + sequence{}; \ + } \ + }() +#endif diff --git a/include/ck_tile/core/utility/transpose_vectors.hpp b/include/ck_tile/core/utility/transpose_vectors.hpp new file mode 100644 index 000000000..a164c3f94 --- /dev/null +++ b/include/ck_tile/core/utility/transpose_vectors.hpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/container/array.hpp" +#include "ck_tile/core/container/thread_buffer.hpp" +#include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/functional.hpp" + +namespace ck_tile { + +// S: scalar type (or it can be non-scalar type) +// NX: # of vector before transpose +// NY: # of vector after transpose +// we got [NX, NY] amount of S data to be transposed into [NY, NX] amount of S data +template +struct transpose_vectors +{ + static constexpr index_t s_per_x = NY; + static constexpr index_t s_per_y = NX; + + using S = remove_cvref_t; + + using VX = array; + using VY = array; + + CK_TILE_DEVICE void operator()(const thread_buffer& vx_tuple, + thread_buffer& vy_tuple) + { + constexpr auto I1 = number<1>{}; + constexpr auto I2 = number<2>{}; + constexpr auto I3 = number<3>{}; + constexpr auto I4 = number<4>{}; + + if constexpr(sizeof(S) == 2) + { + static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!"); + + using S2 = array; // typename array::type; + + // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple + static_for<0, NY, 2>{}([&](auto iy) { + static_for<0, NX, 2>{}([&](auto ix) { + // 2 16bitx2 data from vx_tuple to be transposed + const int32_t x_s2_0 = + bit_cast(vx_tuple[ix].template get_as()[iy / I2]); + const int32_t x_s2_1 = + bit_cast(vx_tuple[ix + I1].template get_as()[iy / I2]); + + constexpr int32_t m0 = 0x05040100; + constexpr int32_t m1 = 0x07060302; + + // transpose 2x2 16bit + // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488 + // -- -- -- -- -- -- -- -- - - - - + // index 7 6 5 4 3 2 1 0 33 77 44 88 + // index is reversed because of little endianness (least significant bits first) + const int32_t y_s2_0 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m0); + const int32_t y_s2_1 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m1); + + // 2 16bitx2 data after transposed + vy_tuple(iy).template get_as()(ix / I2) = bit_cast(y_s2_0); + vy_tuple(iy + I1).template get_as()(ix / I2) = bit_cast(y_s2_1); + }); + }); + } + else if constexpr(sizeof(S) == 1) + { + static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!"); + + using S4 = array; // typename array::type; + + // loop over 4x4 tile and transpose data from vx_tuple into 
vy_tuple + static_for<0, NY, 4>{}([&](auto iy) { + static_for<0, NX, 4>{}([&](auto ix) { + // 4 int8x4 data from vx_tuple + const int32_t x_s4_0 = + bit_cast(vx_tuple[ix].template get_as()[iy / I4]); + const int32_t x_s4_1 = + bit_cast(vx_tuple[ix + I1].template get_as()[iy / I4]); + const int32_t x_s4_2 = + bit_cast(vx_tuple[ix + I2].template get_as()[iy / I4]); + const int32_t x_s4_3 = + bit_cast(vx_tuple[ix + I3].template get_as()[iy / I4]); + + // transpose + int32_t t_s4_0, t_s4_1; + int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3; + + constexpr int32_t m0 = 0x05010400; + constexpr int32_t m1 = 0x05040100; + constexpr int32_t m2 = 0x07060302; + constexpr int32_t m3 = 0x07030602; + + // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488 + // -- -- -- -- -- -- -- -- - - - - + // index 7 6 5 4 3 2 1 0 33 77 44 88 + // index is reversed because of little endianness (least significant bits first) + t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0); + t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0); + y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1); + y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2); + t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3); + t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3); + y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1); + y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2); + + // 4 int8x4 data from vy_tuple + vy_tuple(iy).template get_as()(ix / I4) = bit_cast(y_s4_0); + vy_tuple(iy + I1).template get_as()(ix / I4) = bit_cast(y_s4_1); + vy_tuple(iy + I2).template get_as()(ix / I4) = bit_cast(y_s4_2); + vy_tuple(iy + I3).template get_as()(ix / I4) = bit_cast(y_s4_3); + }); + }); + } + else + { + static_assert(false, "not implemented"); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp new file mode 100644 index 000000000..f5dffda86 --- /dev/null +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include +#include + +namespace ck_tile { + +// remove_cvref_t +template +using remove_reference_t = typename std::remove_reference::type; + +template +using remove_cv_t = typename std::remove_cv::type; + +template +using remove_cvref_t = remove_cv_t>; + +template +using remove_pointer_t = typename std::remove_pointer::type; + +namespace detail { +template class Op, class... Args> +struct detector +{ + using value_t = std::false_type; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> +{ + using value_t = std::true_type; + using type = Op; +}; +} // namespace detail + +struct nonesuch +{ + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; + void operator=(nonesuch const&) = delete; +}; + +template